warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,553 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import cached_property
17
+ from typing import Any, Optional, Set, Union
18
+
19
+ import warp as wp
20
+ import warp._src.fem.cache as cache
21
+ import warp._src.fem.utils as utils
22
+ from warp._src.codegen import Struct, StructInstance
23
+ from warp._src.context import Devicelike
24
+ from warp._src.fem.geometry import (
25
+ Element,
26
+ Geometry,
27
+ GeometryPartition,
28
+ WholeGeometryPartition,
29
+ )
30
+ from warp._src.fem.operator import Operator
31
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, ElementKind
32
+
33
# Special module-name override; presumably read by warp so that functions/kernels defined
# here are registered under the public ``warp.fem.domain`` name rather than the private
# ``warp._src`` path — NOTE(review): confirm against warp's module registration logic.
_wp_module_name_ = "warp.fem.domain"

# Anything a domain can be built from: either a whole Geometry or a GeometryPartition.
GeometryOrPartition = Union[Geometry, GeometryPartition]
36
+
37
+
38
class GeometryDomain:
    """Interface class for domains, i.e. (partial) views of elements in a Geometry

    A domain is built on top of a :class:`GeometryPartition`; when a plain :class:`Geometry`
    is given, it is wrapped in a :class:`WholeGeometryPartition`. Concrete subclasses define
    which elements are visited (e.g. cells or sides) and expose the matching argument
    structures and device functions.

    Note:
        In this base class, :meth:`element_index_arg_value` and :meth:`fill_element_index_arg`
        are implemented in terms of each other, as are :meth:`element_arg_value` and
        :meth:`fill_element_arg`. Subclasses must override at least one method of each pair,
        otherwise calling either one recurses indefinitely.
    """

    def __init__(self, geometry: GeometryOrPartition):
        """
        Args:
            geometry: Geometry, or partition of a geometry, whose elements this domain views.
        """
        if isinstance(geometry, GeometryPartition):
            self.geometry_partition = geometry
        else:
            self.geometry_partition = WholeGeometryPartition(geometry)

        # The underlying geometry of the (possibly whole) partition
        self.geometry = self.geometry_partition.geometry

    @cached_property
    def name(self) -> str:
        """Name built from the partition name and the domain class name; used as suffix for generated structs"""
        return f"{self.geometry_partition.name}_{self.__class__.__name__}"

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other) -> bool:
        # Defer to the other operand for foreign types (Python then falls back to identity, i.e. False)
        if not isinstance(other, GeometryDomain):
            return NotImplemented
        return self.__class__ == other.__class__ and self.geometry_partition == other.geometry_partition

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None; hash on the same fields __eq__ compares
        return hash((self.__class__, self.geometry_partition))

    @property
    def element_kind(self) -> ElementKind:
        """Kind of elements that this domain contains (cells or sides)"""
        raise NotImplementedError

    @property
    def dimension(self) -> int:
        """Dimension of the elements of the domain"""
        raise NotImplementedError

    def element_count(self) -> int:
        """Number of elements in the domain"""
        raise NotImplementedError

    def geometry_element_count(self) -> int:
        """Number of elements in the underlying geometry"""
        return self.geometry.cell_count()

    def reference_element(self) -> Element:
        """Type of reference element"""
        raise NotImplementedError

    def element_index_arg_value(self, device: Devicelike) -> StructInstance:
        """Value of the argument to be passed to device functions computing element indices"""
        args = self.ElementIndexArg()
        self.fill_element_index_arg(args, device)
        return args

    def fill_element_index_arg(self, arg: "GeometryDomain.ElementIndexArg", device: Devicelike):
        """Fills ``arg`` in place with the element-index argument values for ``device``"""
        arg.assign(self.element_index_arg_value(device))

    def element_arg_value(self, device: Devicelike) -> StructInstance:
        """Value of the argument to be passed to device functions computing element geometry"""
        args = self.ElementArg()
        self.fill_element_arg(args, device)
        return args

    def fill_element_arg(self, arg: "GeometryDomain.ElementArg", device: Devicelike):
        """Fills ``arg`` in place with the element-geometry argument values for ``device``"""
        arg.assign(self.element_arg_value(device))

    ElementIndexArg: Struct
    """Structure containing arguments to be passed to device functions computing element indices"""

    element_index: wp.Function
    """Device function for retrieving an ElementIndex from a linearized index"""

    element_partition_index: wp.Function
    """Device function for retrieving linearized index in the domain's partition from an ElementIndex"""

    ElementArg: Struct
    """Structure containing arguments to be passed to device functions computing element geometry"""

    element_measure: wp.Function
    """Device function returning the measure determinant (e.g. volume, area) at a given point"""

    element_measure_ratio: wp.Function
    """Device function returning the ratio of the measure of a side to that of its neighbour cells"""

    element_position: wp.Function
    """Device function returning the element position at a sample point"""

    element_deformation_gradient: wp.Function
    """Device function returning the gradient of the position with respect to the element's reference space"""

    element_normal: wp.Function
    """Device function returning the element normal at a sample point"""

    element_closest_point: wp.Function
    """Device function returning the coordinates of the closest point in a given element to a world position"""

    element_coordinates: wp.Function
    """Device function returning the coordinates corresponding to a world position in a given element reference system"""

    element_lookup: wp.Function
    """Device function returning the sample point in the domain's geometry corresponding to a world position"""

    element_partition_lookup: wp.Function
    """Device function returning the sample point in the domain's geometry partition corresponding to a world position"""

    def notify_operator_usage(self, ops: Set[Operator]):
        """Makes the Domain aware that the operators `ops` will be applied (no-op by default)"""

    @cached_property
    def DomainArg(self):
        """Struct bundling the element geometry (``geo``) and element index (``index``) arguments"""
        return self._make_domain_arg()

    def _make_domain_arg(self):
        # Struct name is suffixed with the domain name so distinct domains get distinct struct types
        @cache.dynamic_struct(suffix=self.name)
        class DomainArg:
            geo: self.ElementArg
            index: self.ElementIndexArg

        return DomainArg
153
+
154
+
155
+ class Cells(GeometryDomain):
156
+ """A Domain containing all cells of the geometry or geometry partition"""
157
+
158
+ def __init__(self, geometry: GeometryOrPartition):
159
+ super().__init__(geometry)
160
+
161
+ @property
162
+ def element_kind(self) -> ElementKind:
163
+ return ElementKind.CELL
164
+
165
+ @property
166
+ def dimension(self) -> int:
167
+ return self.geometry.dimension
168
+
169
+ def reference_element(self) -> Element:
170
+ return self.geometry.reference_cell()
171
+
172
+ def element_count(self) -> int:
173
+ return self.geometry_partition.cell_count()
174
+
175
+ def geometry_element_count(self) -> int:
176
+ return self.geometry.cell_count()
177
+
178
+ @property
179
+ def ElementIndexArg(self) -> Struct:
180
+ return self.geometry_partition.CellArg
181
+
182
+ def element_index_arg_value(self, device: Devicelike) -> StructInstance:
183
+ return self.geometry_partition.cell_arg_value(device)
184
+
185
+ def fill_element_index_arg(self, arg: ElementIndexArg, device: Devicelike):
186
+ self.geometry_partition.fill_cell_arg(arg, device)
187
+
188
+ @property
189
+ def element_index(self) -> wp.Function:
190
+ return self.geometry_partition.cell_index
191
+
192
+ @property
193
+ def element_partition_index(self) -> wp.Function:
194
+ return self.geometry_partition.partition_cell_index
195
+
196
+ def element_arg_value(self, device: Devicelike) -> StructInstance:
197
+ return self.geometry.cell_arg_value(device)
198
+
199
+ def fill_element_arg(self, arg: "ElementArg", device: Devicelike):
200
+ self.geometry.fill_cell_arg(arg, device)
201
+
202
+ @property
203
+ def ElementArg(self) -> Struct:
204
+ return self.geometry.CellArg
205
+
206
+ @property
207
+ def element_position(self) -> wp.Function:
208
+ return self.geometry.cell_position
209
+
210
+ @property
211
+ def element_deformation_gradient(self) -> wp.Function:
212
+ return self.geometry.cell_deformation_gradient
213
+
214
+ @property
215
+ def element_measure(self) -> wp.Function:
216
+ return self.geometry.cell_measure
217
+
218
+ @property
219
+ def element_measure_ratio(self) -> wp.Function:
220
+ return self.geometry.cell_measure_ratio
221
+
222
+ @property
223
+ def element_normal(self) -> wp.Function:
224
+ return self.geometry.cell_normal
225
+
226
+ @property
227
+ def element_closest_point(self) -> wp.Function:
228
+ return self.geometry.cell_closest_point
229
+
230
+ @property
231
+ def element_coordinates(self) -> wp.Function:
232
+ return self.geometry.cell_coordinates
233
+
234
+ @property
235
+ def element_lookup(self) -> wp.Function:
236
+ return self.geometry.cell_lookup
237
+
238
@cached_property
def element_partition_lookup(self) -> wp.Function:
    """Builds (once per domain) a cell lookup restricted to the cells of this domain's geometry partition.

    The returned Warp function has two overloads:
      - ``(args: DomainArg, pos, max_dist: float)``
      - ``(args: DomainArg, pos)``, equivalent to passing ``max_dist = 0.0``

    Both delegate to the geometry's filtered cell lookup, using a filter that accepts
    a cell only if the partition maps it to a valid (non-``NULL_ELEMENT_INDEX``) index.
    """

    # Filter predicate: a geometry cell belongs to the partition iff it has a partition-local index
    @cache.dynamic_func(suffix=self.geometry_partition.name)
    def is_in_partition(args: self.ElementIndexArg, cell_index: int):
        return self.geometry_partition.partition_cell_index(args, cell_index) != NULL_ELEMENT_INDEX

    filtered_cell_lookup = self.geometry.make_filtered_cell_lookup(filter_func=is_in_partition)

    # Captured as a compile-time constant by both overloads below; presumably the value the
    # filter must return for a cell to be accepted -- TODO confirm against make_filtered_cell_lookup
    filter_target = True
    # Position vector type matching the geometry dimension (computed once, only needed from here on)
    pos_type = cache.cached_vec_type(self.geometry.dimension, dtype=float)

    # overloads (same Python name on purpose: registered as overloads via allow_overloads=True)
    @cache.dynamic_func(suffix=self.name, allow_overloads=True)
    def cell_partition_lookup(args: self.DomainArg, pos: pos_type, max_dist: float):
        return filtered_cell_lookup(args.geo, pos, max_dist, args.index, filter_target)

    @cache.dynamic_func(suffix=self.name, allow_overloads=True)
    def cell_partition_lookup(args: self.DomainArg, pos: pos_type):
        max_dist = 0.0
        return filtered_cell_lookup(args.geo, pos, max_dist, args.index, filter_target)

    return cell_partition_lookup
262
+
263
def supports_lookup(self, device):
    """Cell lookup availability is decided by the underlying geometry for the given device"""
    geometry = self.geometry
    return geometry.supports_cell_lookup(device)
265
+
266
@property
def domain_cell_arg(self) -> wp.Function:
    """Conversion from this domain's argument to its cell domain's argument.

    A cell domain is its own cell domain, so the conversion is the identity.
    """
    identity = Cells._identity_fn
    return identity
269
+
270
def cell_domain(self):
    """The cell domain associated with a cell domain is the domain itself"""
    domain = self
    return domain
272
+
273
@wp.func
def _identity_fn(x: Any):
    # Generic pass-through; serves as ``domain_cell_arg`` for cell domains,
    # whose arguments need no conversion to reach their (identical) cell domain.
    return x
276
+
277
+
278
class Sides(GeometryDomain):
    """A Domain containing all (interior and boundary) sides of the geometry or geometry partition"""

    def __init__(self, geometry: GeometryOrPartition):
        self.geometry = geometry
        super().__init__(geometry)

        # Side domains expose no point lookup: clear every lookup entry point
        for lookup_attr in ("element_lookup", "element_partition_lookup", "element_filtered_lookup"):
            setattr(self, lookup_attr, None)

    def supports_lookup(self, device):
        """Point lookup is never available on side domains, whatever the device"""
        return False

    @property
    def element_kind(self) -> ElementKind:
        """Elements of this domain are geometry sides"""
        return ElementKind.SIDE

    @property
    def dimension(self) -> int:
        """Sides have one dimension less than the underlying geometry"""
        geometry_dim = self.geometry.dimension
        return geometry_dim - 1

    def reference_element(self) -> Element:
        """Reference element of the geometry's sides"""
        geo = self.geometry
        return geo.reference_side()

    def element_count(self) -> int:
        """Number of sides in the geometry partition"""
        partition = self.geometry_partition
        return partition.side_count()

    def geometry_element_count(self) -> int:
        """Number of sides in the whole geometry"""
        geo = self.geometry
        return geo.side_count()

    # --- element indexing: delegated to the geometry partition ---

    @property
    def ElementIndexArg(self) -> Struct:
        """Partition-level side index argument struct"""
        partition = self.geometry_partition
        return partition.SideArg

    def element_index_arg_value(self, device: Devicelike) -> StructInstance:
        """Partition-level side index argument value for ``device``"""
        partition = self.geometry_partition
        return partition.side_arg_value(device)

    def fill_element_index_arg(self, arg: "ElementIndexArg", device: Devicelike):
        """Fills ``arg`` in place with the partition-level side index argument"""
        partition = self.geometry_partition
        partition.fill_side_arg(arg, device)

    @property
    def element_index(self) -> wp.Function:
        """Partition function mapping a domain element number to a side index"""
        partition = self.geometry_partition
        return partition.side_index

    # --- per-element evaluation: delegated to the geometry's side functions ---

    @property
    def ElementArg(self) -> Struct:
        """Geometry-level side argument struct"""
        geo = self.geometry
        return geo.SideArg

    def element_arg_value(self, device: Devicelike) -> StructInstance:
        """Geometry-level side argument value for ``device``"""
        geo = self.geometry
        return geo.side_arg_value(device)

    def fill_element_arg(self, arg: "ElementArg", device: Devicelike):
        """Fills ``arg`` in place with the geometry-level side argument"""
        geo = self.geometry
        geo.fill_side_arg(arg, device)

    @property
    def element_position(self) -> wp.Function:
        """Forwards to the geometry's ``side_position``"""
        geo = self.geometry
        return geo.side_position

    @property
    def element_deformation_gradient(self) -> wp.Function:
        """Forwards to the geometry's ``side_deformation_gradient``"""
        geo = self.geometry
        return geo.side_deformation_gradient

    @property
    def element_measure(self) -> wp.Function:
        """Forwards to the geometry's ``side_measure``"""
        geo = self.geometry
        return geo.side_measure

    @property
    def element_measure_ratio(self) -> wp.Function:
        """Forwards to the geometry's ``side_measure_ratio``"""
        geo = self.geometry
        return geo.side_measure_ratio

    @property
    def element_normal(self) -> wp.Function:
        """Forwards to the geometry's ``side_normal``"""
        geo = self.geometry
        return geo.side_normal

    @property
    def element_closest_point(self) -> wp.Function:
        """Forwards to the geometry's ``side_closest_point``"""
        geo = self.geometry
        return geo.side_closest_point

    @property
    def element_coordinates(self) -> wp.Function:
        """Forwards to the geometry's ``side_coordinates``"""
        geo = self.geometry
        return geo.side_coordinates

    @property
    def element_inner_cell_index(self) -> wp.Function:
        """Forwards to the geometry's ``side_inner_cell_index``"""
        geo = self.geometry
        return geo.side_inner_cell_index

    @property
    def element_outer_cell_index(self) -> wp.Function:
        """Forwards to the geometry's ``side_outer_cell_index``"""
        geo = self.geometry
        return geo.side_outer_cell_index

    @property
    def element_inner_cell_coords(self) -> wp.Function:
        """Forwards to the geometry's ``side_inner_cell_coords``"""
        geo = self.geometry
        return geo.side_inner_cell_coords

    @property
    def element_outer_cell_coords(self) -> wp.Function:
        """Forwards to the geometry's ``side_outer_cell_coords``"""
        geo = self.geometry
        return geo.side_outer_cell_coords

    @property
    def cell_to_element_coords(self) -> wp.Function:
        """Forwards to the geometry's ``side_from_cell_coords``"""
        geo = self.geometry
        return geo.side_from_cell_coords

    @cached_property
    def domain_cell_arg(self) -> wp.Function:
        """Warp function converting this side domain's argument into an argument for
        the cell domain of the same geometry partition (see ``cell_domain``)"""
        cell_arg_struct = self.cell_domain().DomainArg

        @cache.dynamic_func(suffix=self.name)
        def domain_cell_arg(x: self.DomainArg):
            return cell_arg_struct(
                self.geometry.side_to_cell_arg(x.geo),
                self.geometry_partition.side_to_cell_arg(x.index),
            )

        return domain_cell_arg

    def cell_domain(self):
        """A new ``Cells`` domain over the same geometry partition"""
        partition = self.geometry_partition
        return Cells(partition)
396
+
397
+
398
class BoundarySides(Sides):
    """A Domain containing boundary sides of the geometry or geometry partition"""

    def __init__(self, geometry: GeometryOrPartition):
        super().__init__(geometry)

    @property
    def element_index(self) -> wp.Function:
        """Partition function mapping a domain element number to a boundary side index"""
        partition = self.geometry_partition
        return partition.boundary_side_index

    def element_count(self) -> int:
        """Number of boundary sides in the geometry partition"""
        partition = self.geometry_partition
        return partition.boundary_side_count()

    def geometry_element_count(self) -> int:
        """Number of boundary sides in the whole geometry"""
        geo = self.geometry
        return geo.boundary_side_count()
413
+
414
+
415
class FrontierSides(Sides):
    """A Domain containing frontier sides of the geometry partition (sides shared with at least one other partition)"""

    def __init__(self, geometry: GeometryOrPartition):
        super().__init__(geometry)

    @property
    def element_index(self) -> wp.Function:
        """Partition function mapping a domain element number to a frontier side index"""
        partition = self.geometry_partition
        return partition.frontier_side_index

    def element_count(self) -> int:
        """Number of frontier sides in the geometry partition"""
        partition = self.geometry_partition
        return partition.frontier_side_count()

    def geometry_element_count(self) -> int:
        """Frontier sides only exist relative to a partition, so there is no geometry-level count"""
        raise RuntimeError("Frontier sides not defined at the geometry level")
430
+
431
+
432
class Subdomain(GeometryDomain):
    """Subdomain -- restriction of domain to a subset of its elements"""

    def __init__(
        self,
        domain: GeometryDomain,
        element_mask: Optional[wp.array] = None,
        element_indices: Optional[wp.array] = None,
        temporary_store: Optional[cache.TemporaryStore] = None,
    ):
        """
        Create a subdomain from a subset of elements.

        Exactly one of `element_mask` and `element_indices` should be provided.

        Args:
            domain: the containing domain
            element_mask: Array of length ``domain.element_count()`` indicating which elements should be included. Array values must be either ``1`` (selected) or ``0`` (not selected).
            element_indices: Explicit array of element indices to include, expressed in the element numbering of ``domain``
            temporary_store: Optional temporary store forwarded to the mask-to-indices conversion; only used when ``element_mask`` is provided

        Raises:
            ValueError: if both or neither of ``element_mask`` and ``element_indices`` are provided
        """

        super().__init__(domain.geometry_partition)

        if element_indices is None:
            if element_mask is None:
                raise ValueError("Either 'element_mask' or 'element_indices' should be provided")
            element_indices, _ = utils.masked_indices(mask=element_mask, temporary_store=temporary_store)
            # NOTE(review): presumably detaches the index array from the temporary store so the
            # subdomain keeps it alive independently -- confirm against cache.TemporaryStore
            element_indices = element_indices.detach()
        elif element_mask is not None:
            raise ValueError("Only one of 'element_mask' and 'element_indices' should be provided")

        self._domain = domain
        self._element_indices = element_indices
        self.ElementIndexArg = self._make_element_index_arg()
        self.element_index = self._make_element_index()

        # forward: per-element evaluation is identical to the containing domain;
        # only the element indexing differs
        self.ElementArg = self._domain.ElementArg
        self.geometry_element_count = self._domain.geometry_element_count
        self.reference_element = self._domain.reference_element
        self.element_arg_value = self._domain.element_arg_value
        self.fill_element_arg = self._domain.fill_element_arg
        self.element_measure = self._domain.element_measure
        self.element_measure_ratio = self._domain.element_measure_ratio
        self.element_position = self._domain.element_position
        self.element_deformation_gradient = self._domain.element_deformation_gradient
        self.element_lookup = self._domain.element_lookup
        self.element_partition_lookup = self._domain.element_partition_lookup
        self.element_normal = self._domain.element_normal

    @property
    def name(self) -> str:
        return f"{self._domain.name}_Subdomain"

    def __eq__(self, other) -> bool:
        return (
            self.__class__ == other.__class__
            and self.geometry_partition == other.geometry_partition
            and self._element_indices == other._element_indices
        )

    @property
    def element_kind(self) -> ElementKind:
        return self._domain.element_kind

    @property
    def dimension(self) -> int:
        return self._domain.dimension

    def element_count(self) -> int:
        """Number of selected elements"""
        return self._element_indices.shape[0]

    def _make_element_index_arg(self):
        # Index argument: the containing domain's index arg plus the selected element indices
        @cache.dynamic_struct(suffix=self.name)
        class ElementIndexArg:
            domain_arg: self._domain.ElementIndexArg
            element_indices: wp.array(dtype=int)

        return ElementIndexArg

    @cache.cached_arg_value
    def element_index_arg_value(self, device: Devicelike):
        arg = self.ElementIndexArg()
        self.fill_element_index_arg(arg, device)
        return arg

    def fill_element_index_arg(self, arg: "GeometryDomain.ElementIndexArg", device: Devicelike):
        self._domain.fill_element_index_arg(arg.domain_arg, device)
        arg.element_indices = self._element_indices.to(device)

    def _make_element_index(self) -> wp.Function:
        # Subdomain element `index` -> containing-domain element `element_indices[index]`
        @cache.dynamic_func(suffix=self.name)
        def element_index(arg: self.ElementIndexArg, index: int):
            return self._domain.element_index(arg.domain_arg, arg.element_indices[index])

        return element_index

    def _make_element_partition_index(self) -> wp.Function:
        # NOTE(review): not called from within this class
        @cache.dynamic_func(suffix=self.name)
        def element_partition_index(arg: self.ElementIndexArg, element_index: int):
            return self._domain.element_partition_index(arg.domain_arg, element_index)

        return element_partition_index

    def supports_lookup(self, device):
        # Lookup support is that of the containing domain
        # (was misspelled `supports_lokup`, which raised AttributeError)
        return self._domain.supports_lookup(device)

    def cell_domain(self):
        return self._domain.cell_domain()

    @cached_property
    def domain_cell_arg(self) -> wp.Function:
        """Warp function converting this subdomain's argument into an argument for its cell domain.

        Side subdomains go through the geometry/partition ``side_to_cell_arg`` conversions.
        Cell subdomains share their cell domain with the containing domain, so the forwarded
        geometry and index arguments are used directly.
        NOTE(review): nested Subdomains (a Subdomain of a Subdomain) are not handled here.
        """
        CellDomainArg = self.cell_domain().DomainArg

        if self.element_kind == ElementKind.SIDE:

            @cache.dynamic_func(suffix=self.name)
            def domain_cell_arg(x: self.DomainArg):
                return CellDomainArg(
                    self.geometry.side_to_cell_arg(x.geo),
                    self.geometry_partition.side_to_cell_arg(x.index.domain_arg),
                )

            return domain_cell_arg

        @cache.dynamic_func(suffix=self.name)
        def domain_cell_arg(x: self.DomainArg):
            return CellDomainArg(x.geo, x.index.domain_arg)

        return domain_cell_arg
@@ -0,0 +1,131 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Union
17
+
18
+ from warp._src.fem.domain import Cells, GeometryDomain
19
+ from warp._src.fem.space import (
20
+ FunctionSpace,
21
+ SpacePartition,
22
+ SpaceRestriction,
23
+ make_space_partition,
24
+ make_space_restriction,
25
+ )
26
+
27
+ from .field import DiscreteField, FieldLike, GeometryField, ImplicitField, NonconformingField, SpaceField, UniformField
28
+ from .nodal_field import NodalField
29
+ from .restriction import FieldRestriction
30
+ from .virtual import LocalTestField, LocalTrialField, TestField, TrialField
31
+
32
+
33
def make_restriction(
    field: DiscreteField,
    space_restriction: Optional[SpaceRestriction] = None,
    domain: Optional[GeometryDomain] = None,
    device=None,
) -> FieldRestriction:
    """
    Restricts a discrete field to a subset of elements.

    Args:
        field: the discrete field to restrict
        space_restriction: the function space restriction defining the subset of elements to consider
        domain: if ``space_restriction`` is not provided, the :py:class:`Domain` defining the subset of elements to consider
        device: Warp device on which to perform and store computations

    Returns:
        the field restriction
    """

    if space_restriction is not None:
        return FieldRestriction(field=field, space_restriction=space_restriction)

    # No explicit restriction: build one from the field's own space partition and the given domain
    built_restriction = make_space_restriction(space_partition=field.space_partition, domain=domain, device=device)
    return FieldRestriction(field=field, space_restriction=built_restriction)
56
+
57
+
58
def make_test(
    space: FunctionSpace,
    space_restriction: Optional[SpaceRestriction] = None,
    space_partition: Optional[SpacePartition] = None,
    domain: Optional[GeometryDomain] = None,
    device=None,
) -> TestField:
    """
    Constructs a test field over a function space or its restriction

    Args:
        space: the function space
        space_restriction: restriction of the space topology to a domain
        space_partition: if `space_restriction` is ``None``, the optional subset of node indices to consider
        domain: if `space_restriction` is ``None``, optional subset of elements to consider
        device: Warp device on which to perform and store computations

    Returns:
        the test field
    """

    restriction = space_restriction
    if restriction is None:
        # Derive the restriction from the space topology, honoring the optional partition and domain
        restriction = make_space_restriction(
            space_topology=space.topology,
            space_partition=space_partition,
            domain=domain,
            device=device,
        )

    return TestField(space_restriction=restriction, space=space)
85
+
86
+
87
def make_trial(
    space: FunctionSpace,
    space_restriction: Optional[SpaceRestriction] = None,
    space_partition: Optional[SpacePartition] = None,
    domain: Optional[GeometryDomain] = None,
) -> TrialField:
    """
    Constructs a trial field over a function space or partition

    Args:
        space: the function space
        space_restriction: restriction of the space topology to a domain. If provided, its ``domain``
            and ``space_partition`` are used and the ``domain`` and ``space_partition`` arguments are ignored
        space_partition: if `space_restriction` is ``None``, the optional subset of node indices to consider.
            If omitted, a partition is built from ``space.topology`` and the geometry partition of ``domain``
        domain: if `space_restriction` is ``None``, optional subset of elements to consider.
            If omitted, defaults to the cells of ``space_partition.geo_partition`` when a partition is available,
            or to the cells of ``space.geometry`` otherwise

    Returns:
        the trial field
    """

    if space_restriction is not None:
        # An explicit restriction overrides the individual domain/partition arguments
        domain = space_restriction.domain
        space_partition = space_restriction.space_partition

    if space_partition is None:
        if domain is None:
            domain = Cells(geometry=space.geometry)
        space_partition = make_space_partition(
            space_topology=space.topology, geometry_partition=domain.geometry_partition
        )
    elif domain is None:
        domain = Cells(geometry=space_partition.geo_partition)

    return TrialField(space, space_partition, domain)
121
+
122
+
123
def make_discrete_field(
    space: FunctionSpace,
    space_partition: Optional[SpacePartition] = None,
) -> DiscreteField:
    """Constructs a zero-initialized discrete field over a function space or partition

    Thin convenience wrapper: field creation is delegated to the space itself.

    See also: :meth:`warp.fem.FunctionSpace.make_field`
    """
    field = space.make_field(space_partition=space_partition)
    return field