warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
###########################################################################
|
|
9
17
|
# Example Distortion Energy
|
|
@@ -74,7 +82,7 @@ def boundary_projector_form(
|
|
|
74
82
|
):
|
|
75
83
|
# Fix a single point
|
|
76
84
|
# (underconstrained, solution up to a rotation in UV space)
|
|
77
|
-
w = wp.select(s.qp_index == 0, 0.0, 1.0)
|
|
85
|
+
w = wp.where(s.qp_index == 0, 1.0, 0.0)
|
|
78
86
|
return w * wp.dot(u(s), v(s))
|
|
79
87
|
|
|
80
88
|
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
###########################################################################
|
|
9
17
|
# Example Magnetostatics
|
|
@@ -52,8 +60,8 @@ def cube_to_cylinder_grad(x: wp.vec3):
|
|
|
52
60
|
dir_grad = (wp.identity(n=3, dtype=float) - wp.outer(dir_xz, dir_xz)) / wp.length(pos_xz)
|
|
53
61
|
|
|
54
62
|
abs_xz = wp.abs(pos_xz)
|
|
55
|
-
xinf_grad = wp.select(
|
|
56
|
-
abs_xz[0] > abs_xz[2], wp.vec3(0.0, 0.0, wp.sign(pos_xz[2])), wp.vec(wp.sign(pos_xz[0]), 0.0, 0.0)
|
|
63
|
+
xinf_grad = wp.where(
|
|
64
|
+
abs_xz[0] > abs_xz[2], wp.vec(wp.sign(pos_xz[0]), 0.0, 0.0), wp.vec3(0.0, 0.0, wp.sign(pos_xz[2]))
|
|
57
65
|
)
|
|
58
66
|
grad = dir_grad * wp.max(abs_xz) + wp.outer(dir_xz, xinf_grad)
|
|
59
67
|
|
|
@@ -77,10 +85,10 @@ def permeability_field(
|
|
|
77
85
|
r = wp.sqrt(x * x + z * z)
|
|
78
86
|
|
|
79
87
|
if r <= core_radius:
|
|
80
|
-
return wp.select(y < core_height, MU_0, MU_i)
|
|
88
|
+
return wp.where(y < core_height, MU_i, MU_0)
|
|
81
89
|
|
|
82
90
|
if r >= coil_internal_radius and r <= coil_external_radius:
|
|
83
|
-
return wp.select(y < coil_height, MU_0, MU_c)
|
|
91
|
+
return wp.where(y < coil_height, MU_c, MU_0)
|
|
84
92
|
|
|
85
93
|
return MU_0
|
|
86
94
|
|
|
@@ -99,10 +107,10 @@ def current_field(
|
|
|
99
107
|
|
|
100
108
|
r = wp.sqrt(x * x + z * z)
|
|
101
109
|
|
|
102
|
-
return wp.select(
|
|
110
|
+
return wp.where(
|
|
103
111
|
y < coil_height and r >= coil_internal_radius and r <= coil_external_radius,
|
|
104
|
-
wp.vec3(0.0),
|
|
105
112
|
wp.vec3(z, 0.0, -x) * current / r,
|
|
113
|
+
wp.vec3(0.0),
|
|
106
114
|
)
|
|
107
115
|
|
|
108
116
|
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
###########################################################################
|
|
9
17
|
# Example Mixed Elasticity
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
###########################################################################
|
|
9
17
|
# Example Navier Stokes
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
###########################################################################
|
|
9
17
|
# Example Nonconforming Contact
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
###########################################################################
|
|
9
17
|
# Example Stokes
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
###########################################################################
|
|
9
17
|
# Example Stokes Transfer
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
###########################################################################
|
|
9
17
|
# Example Streamlines
|
warp/examples/fem/utils.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any, Dict, Optional, Tuple
|
|
2
17
|
|
|
3
18
|
import numpy as np
|
|
@@ -19,6 +34,9 @@ __all__ = [
|
|
|
19
34
|
"Plot",
|
|
20
35
|
]
|
|
21
36
|
|
|
37
|
+
# matrix inversion routines contain nested loops,
|
|
38
|
+
# default unrolling leads to code explosion
|
|
39
|
+
wp.set_module_options({"max_unroll": 6})
|
|
22
40
|
|
|
23
41
|
#
|
|
24
42
|
# Mesh utilities
|
|
@@ -210,6 +228,7 @@ def bsr_cg(
|
|
|
210
228
|
mv_routine=None,
|
|
211
229
|
quiet=False,
|
|
212
230
|
method: str = "cg",
|
|
231
|
+
M: BsrMatrix = None,
|
|
213
232
|
) -> Tuple[float, int]:
|
|
214
233
|
"""Solves the linear system A x = b using an iterative solver, optionally with diagonal preconditioning
|
|
215
234
|
|
|
@@ -230,7 +249,9 @@ def bsr_cg(
|
|
|
230
249
|
|
|
231
250
|
"""
|
|
232
251
|
|
|
233
|
-
if mv_routine is None:
|
|
252
|
+
if M is not None:
|
|
253
|
+
M = aslinearoperator(M)
|
|
254
|
+
elif mv_routine is None:
|
|
234
255
|
M = preconditioner(A, "diag") if use_diag_precond else None
|
|
235
256
|
else:
|
|
236
257
|
A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)
|
|
@@ -443,7 +464,7 @@ def bsr_solve_saddle(
|
|
|
443
464
|
return err, end_iter
|
|
444
465
|
|
|
445
466
|
|
|
446
|
-
@wp.kernel
|
|
467
|
+
@wp.kernel(enable_backward=False)
|
|
447
468
|
def _compute_schur_inverse_diagonal(
|
|
448
469
|
B_offsets: wp.array(dtype=int),
|
|
449
470
|
B_indices: wp.array(dtype=int),
|
|
@@ -485,7 +506,7 @@ def invert_diagonal_bsr_matrix(A: BsrMatrix):
|
|
|
485
506
|
)
|
|
486
507
|
|
|
487
508
|
|
|
488
|
-
@wp.kernel
|
|
509
|
+
@wp.kernel(enable_backward=False)
|
|
489
510
|
def _block_diagonal_invert(values: wp.array(dtype=Any)):
|
|
490
511
|
i = wp.tid()
|
|
491
512
|
values[i] = fem.utils.inverse_qr(values[i])
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
###########################################################################
|
|
17
|
+
# Example jax_callable()
|
|
18
|
+
#
|
|
19
|
+
# Examples of calling annotated Python functions from JAX.
|
|
20
|
+
###########################################################################
|
|
21
|
+
|
|
22
|
+
from functools import partial
|
|
23
|
+
|
|
24
|
+
import jax
|
|
25
|
+
import jax.numpy as jnp
|
|
26
|
+
|
|
27
|
+
import warp as wp
|
|
28
|
+
from warp.jax_experimental.ffi import jax_callable
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    # One thread per element: read a[i], multiply by the scalar s, store in output[i].
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    # One thread per vec2 element: multiply the vector a[i] by the scalar s.
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# The Python function to call.
|
|
44
|
+
# Note the argument annotations, just like Warp kernels.
|
|
45
|
+
def example_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    """Scale ``a`` into ``c`` and ``b`` into ``d`` by the factor ``s``.

    Launches ``scale_kernel`` over ``a.shape`` and ``scale_vec_kernel`` over
    ``b.shape``; results are written to the output arrays, nothing is returned.
    """
    scalar_inputs = [a, s]
    vector_inputs = [b, s]

    wp.launch(scale_kernel, dim=a.shape, inputs=scalar_inputs, outputs=[c])
    wp.launch(scale_vec_kernel, dim=b.shape, inputs=vector_inputs, outputs=[d])
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def example1():
    """Call ``example_func`` from a jitted function that builds its own inputs, then print both results."""
    warp_call = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")

    @jax.jit
    def run():
        # inputs
        scalars = jnp.arange(10, dtype=jnp.float32)
        pairs = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2
        factor = 2.0

        # output shapes, keyed by the names of example_func's output parameters
        shapes = {"c": scalars.shape, "d": pairs.shape}

        scaled_scalars, scaled_pairs = warp_call(scalars, pairs, factor, output_dims=shapes)
        return scaled_scalars, scaled_pairs

    r1, r2 = run()
    for result in (r1, r2):
        print(result)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def example2():
    """Call ``example_func`` from a jitted function that receives its inputs as arguments, then print both results."""
    warp_call = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")

    def scale_arrays(a, b, s):
        # output shapes
        shapes = {"c": a.shape, "d": b.shape}
        scaled_a, scaled_b = warp_call(a, b, s, output_dims=shapes)
        return scaled_a, scaled_b

    # NOTE: scalar arguments must be static compile-time constants
    jitted = jax.jit(scale_arrays, static_argnames=["s"])

    # inputs
    scalars = jnp.arange(10, dtype=jnp.float32)
    pairs = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2
    factor = 3.0

    r1, r2 = jitted(scalars, pairs, factor)
    for result in (r1, r2):
        print(result)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def main():
    """Initialize Warp, call ``wp.load_module`` for the device from ``wp.get_device()``, then run each example in order."""
    wp.init()
    wp.load_module(device=wp.get_device())

    banner = "\n==========================================================================="

    for demo in (example1, example2):
        print(banner)
        print(f"{demo.__name__}:")
        demo()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
###########################################################################
|
|
17
|
+
# Example register_ffi_callback()
|
|
18
|
+
#
|
|
19
|
+
# Examples of calling Python functions from JAX.
|
|
20
|
+
# Target functions must have the form func(inputs, outputs, attrs, ctx).
|
|
21
|
+
###########################################################################
|
|
22
|
+
|
|
23
|
+
import jax
|
|
24
|
+
import jax.numpy as jnp
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
import warp as wp
|
|
28
|
+
from warp.jax import get_jax_device
|
|
29
|
+
from warp.jax_experimental.ffi import register_ffi_callback
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    # One thread per element: read a[i], multiply by the scalar s, store in output[i].
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    # One thread per vec2 element: multiply the vector a[i] by the scalar s.
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def example1():
    """Register a callback that prints its arguments and invoke it through ``jax.ffi.ffi_call``."""

    # the Python function to call
    def print_args(inputs, outputs, attrs, ctx):
        def describe(buf):
            # dtype, then shape as a list, then the `data` value in hexadecimal
            return "".join((str(buf.dtype), str(list(buf.shape)), " @%x" % buf.data))

        print("Inputs: ", ", ".join(describe(buf) for buf in inputs))
        print("Outputs: ", ", ".join(describe(buf) for buf in outputs))
        print("Attributes: ", "".join("\n %s: %s" % (key, str(value)) for key, value in attrs.items()))

    # register callback
    register_ffi_callback("print_args", print_args)

    # set up call
    print_call = jax.ffi.ffi_call("print_args", jax.ShapeDtypeStruct((1, 2, 3), jnp.int8))

    # call it with two arrays and three keyword attributes
    int_input = jnp.arange(16)
    float_input = jnp.arange(32.0).reshape((4, 8))
    print_call(
        int_input,
        float_input,
        str_attr="hi",
        f32_attr=np.float32(4.2),
        dict_attr={"a": 1, "b": 6.4},
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def example2():
    """Register a callback that launches the two scaling kernels, call it via ``jax.ffi.ffi_call``, and print the results."""

    # the Python function to call
    def warp_func(inputs, outputs, attrs, ctx):
        # input arrays
        a, b = inputs[0], inputs[1]

        # scalar attributes
        s = attrs["scale"]

        # output arrays
        c, d = outputs[0], outputs[1]

        # wrap the stream handed over in ctx so the launches below are issued on it
        device = wp.device_from_jax(get_jax_device())
        stream = wp.Stream(device, cuda_stream=ctx.stream)

        with wp.ScopedStream(stream):
            # launch with arrays of scalars
            wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])

            # launch with arrays of vec2
            # NOTE: the input shapes come from JAX arrays, so the inner dimension
            # must be stripped for vec2 arrays
            wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])

    # register callback
    register_ffi_callback("warp_func", warp_func)

    count = 10

    # inputs
    scalars = jnp.arange(count, dtype=jnp.float32)
    pairs = jnp.arange(count, dtype=jnp.float32).reshape((count // 2, 2))  # array of wp.vec2
    factor = 2.0

    # set up call: one float32 output per input, with the same shape
    # (the second one is the array of wp.vec2)
    out_types = [jax.ShapeDtypeStruct(arr.shape, jnp.float32) for arr in (scalars, pairs)]
    scale_call = jax.ffi.ffi_call("warp_func", out_types)

    # call it
    scaled_scalars, scaled_pairs = scale_call(scalars, pairs, scale=factor)

    for result in (scaled_scalars, scaled_pairs):
        print(result)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def main():
    """Initialize Warp, call ``wp.load_module`` for the device from ``wp.get_device()``, then run each example in order."""
    wp.init()
    wp.load_module(device=wp.get_device())

    banner = "\n==========================================================================="

    for demo in (example1, example2):
        print(banner)
        print(f"{demo.__name__}:")
        demo()


if __name__ == "__main__":
    main()
|