warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
|
|
3
18
|
import numpy as np
|
|
@@ -446,8 +461,8 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
446
461
|
node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
|
|
447
462
|
|
|
448
463
|
if node_type == SquareSerendipityShapeFunctions.VERTEX:
|
|
449
|
-
cx = wp.
|
|
450
|
-
cy = wp.
|
|
464
|
+
cx = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
|
|
465
|
+
cy = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
|
|
451
466
|
|
|
452
467
|
w = cx * cy
|
|
453
468
|
|
|
@@ -460,7 +475,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
460
475
|
|
|
461
476
|
w = float(1.0)
|
|
462
477
|
if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
|
|
463
|
-
w *= wp.
|
|
478
|
+
w *= wp.where(node_i == 0, 1.0 - coords[0], coords[0])
|
|
464
479
|
else:
|
|
465
480
|
for k in range(ORDER_PLUS_ONE):
|
|
466
481
|
if k != node_i:
|
|
@@ -469,7 +484,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
469
484
|
w *= LAGRANGE_SCALE[node_i]
|
|
470
485
|
|
|
471
486
|
if node_type == SquareSerendipityShapeFunctions.EDGE_X:
|
|
472
|
-
w *= wp.
|
|
487
|
+
w *= wp.where(node_j == 0, 1.0 - coords[1], coords[1])
|
|
473
488
|
else:
|
|
474
489
|
for k in range(ORDER_PLUS_ONE):
|
|
475
490
|
if k != node_j:
|
|
@@ -498,11 +513,11 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
498
513
|
node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
|
|
499
514
|
|
|
500
515
|
if node_type == SquareSerendipityShapeFunctions.VERTEX:
|
|
501
|
-
cx = wp.
|
|
502
|
-
cy = wp.
|
|
516
|
+
cx = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
|
|
517
|
+
cy = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
|
|
503
518
|
|
|
504
|
-
gx = wp.
|
|
505
|
-
gy = wp.
|
|
519
|
+
gx = wp.where(node_i == 0, -1.0, 1.0)
|
|
520
|
+
gy = wp.where(node_j == 0, -1.0, 1.0)
|
|
506
521
|
|
|
507
522
|
if ORDER == 2:
|
|
508
523
|
w = cx + cy - 2.0 + LOBATTO_COORDS[1]
|
|
@@ -522,7 +537,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
522
537
|
return wp.vec2(grad_x, grad_y) * DEGREE_3_CIRCLE_SCALE
|
|
523
538
|
|
|
524
539
|
if node_type == SquareSerendipityShapeFunctions.EDGE_X:
|
|
525
|
-
prefix_x = wp.
|
|
540
|
+
prefix_x = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
|
|
526
541
|
else:
|
|
527
542
|
prefix_x = LAGRANGE_SCALE[node_j]
|
|
528
543
|
for k in range(ORDER_PLUS_ONE):
|
|
@@ -530,7 +545,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
530
545
|
prefix_x *= coords[1] - LOBATTO_COORDS[k]
|
|
531
546
|
|
|
532
547
|
if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
|
|
533
|
-
prefix_y = wp.
|
|
548
|
+
prefix_y = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
|
|
534
549
|
else:
|
|
535
550
|
prefix_y = LAGRANGE_SCALE[node_i]
|
|
536
551
|
for k in range(ORDER_PLUS_ONE):
|
|
@@ -538,7 +553,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
538
553
|
prefix_y *= coords[0] - LOBATTO_COORDS[k]
|
|
539
554
|
|
|
540
555
|
if node_type == SquareSerendipityShapeFunctions.EDGE_X:
|
|
541
|
-
grad_y = wp.
|
|
556
|
+
grad_y = wp.where(node_j == 0, -1.0, 1.0) * prefix_y
|
|
542
557
|
else:
|
|
543
558
|
prefix_y *= LAGRANGE_SCALE[node_j]
|
|
544
559
|
grad_y = float(0.0)
|
|
@@ -549,7 +564,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
549
564
|
prefix_y *= delta_y
|
|
550
565
|
|
|
551
566
|
if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
|
|
552
|
-
grad_x = wp.
|
|
567
|
+
grad_x = wp.where(node_i == 0, -1.0, 1.0) * prefix_x
|
|
553
568
|
else:
|
|
554
569
|
prefix_x *= LAGRANGE_SCALE[node_i]
|
|
555
570
|
grad_x = float(0.0)
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import numpy as np
|
|
2
17
|
|
|
3
18
|
import warp as wp
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import numpy as np
|
|
2
17
|
|
|
3
18
|
import warp as wp
|
|
@@ -181,7 +196,7 @@ class TrianglePolynomialShapeFunctions(TriangleShapeFunction):
|
|
|
181
196
|
def trace_node_quadrature_weight(node_index_in_element: int):
|
|
182
197
|
node_type, type_index = self.node_type_and_type_index(node_index_in_element)
|
|
183
198
|
|
|
184
|
-
return wp.
|
|
199
|
+
return wp.where(node_type == TrianglePolynomialShapeFunctions.VERTEX, VERTEX_WEIGHT, EDGE_WEIGHT)
|
|
185
200
|
|
|
186
201
|
return trace_node_quadrature_weight
|
|
187
202
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import warp as wp
|
|
2
17
|
from warp.fem import cache
|
|
3
18
|
from warp.fem.geometry import Tetmesh
|
|
@@ -229,10 +244,10 @@ class TetmeshSpaceTopology(SpaceTopology):
|
|
|
229
244
|
edge = type_index // INTERIOR_NODES_PER_EDGE
|
|
230
245
|
c1, c2 = TetrahedronShapeFunction.edge_vidx(edge)
|
|
231
246
|
|
|
232
|
-
return wp.
|
|
247
|
+
return wp.where(
|
|
233
248
|
geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2],
|
|
234
|
-
1.0,
|
|
235
249
|
-1.0,
|
|
250
|
+
1.0,
|
|
236
251
|
)
|
|
237
252
|
|
|
238
253
|
if wp.static(INTERIOR_NODES_PER_FACE > 0):
|
|
@@ -242,7 +257,7 @@ class TetmeshSpaceTopology(SpaceTopology):
|
|
|
242
257
|
global_face_index = topo_arg.tet_face_indices[element_index][face]
|
|
243
258
|
inner = topo_arg.face_tet_indices[global_face_index][0]
|
|
244
259
|
|
|
245
|
-
return wp.
|
|
260
|
+
return wp.where(inner == element_index, 1.0, -1.0)
|
|
246
261
|
|
|
247
262
|
return 1.0
|
|
248
263
|
|
warp/fem/space/topology.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Optional, Tuple, Type
|
|
2
17
|
|
|
3
18
|
import warp as wp
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import warp as wp
|
|
2
17
|
from warp.fem import cache
|
|
3
18
|
from warp.fem.geometry import Trimesh
|
|
@@ -160,11 +175,11 @@ class TrimeshSpaceTopology(SpaceTopology):
|
|
|
160
175
|
edge = type_index // INTERIOR_NODES_PER_SIDE
|
|
161
176
|
|
|
162
177
|
global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
|
|
163
|
-
return wp.
|
|
178
|
+
return wp.where(
|
|
164
179
|
topo_arg.edge_vertex_indices[global_edge_index][0]
|
|
165
180
|
== geo_arg.topology.tri_vertex_indices[element_index][edge],
|
|
166
|
-
-1.0,
|
|
167
181
|
1.0,
|
|
182
|
+
-1.0,
|
|
168
183
|
)
|
|
169
184
|
|
|
170
185
|
return 1.0
|
warp/fem/types.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from enum import Enum
|
|
2
17
|
|
|
3
18
|
import warp as wp
|
warp/fem/utils.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Tuple, Union
|
|
2
17
|
|
|
3
18
|
import numpy as np
|
|
@@ -41,13 +56,11 @@ def compress_node_indices(
|
|
|
41
56
|
sorted_node_indices = sorted_node_indices_temp.array
|
|
42
57
|
sorted_array_indices = sorted_array_indices_temp.array
|
|
43
58
|
|
|
44
|
-
wp.copy(dest=sorted_node_indices, src=node_indices, count=index_count)
|
|
45
|
-
|
|
46
59
|
indices_per_element = 1 if node_indices.ndim == 1 else node_indices.shape[-1]
|
|
47
60
|
wp.launch(
|
|
48
|
-
kernel=
|
|
61
|
+
kernel=_prepare_node_sort_kernel,
|
|
49
62
|
dim=index_count,
|
|
50
|
-
inputs=[sorted_array_indices, indices_per_element],
|
|
63
|
+
inputs=[node_indices.flatten(), sorted_node_indices, sorted_array_indices, indices_per_element],
|
|
51
64
|
)
|
|
52
65
|
|
|
53
66
|
# Sort indices
|
|
@@ -154,8 +167,16 @@ def masked_indices(
|
|
|
154
167
|
|
|
155
168
|
|
|
156
169
|
@wp.kernel
|
|
157
|
-
def
|
|
158
|
-
|
|
170
|
+
def _prepare_node_sort_kernel(
|
|
171
|
+
node_indices: wp.array(dtype=int),
|
|
172
|
+
sort_keys: wp.array(dtype=int),
|
|
173
|
+
sort_values: wp.array(dtype=int),
|
|
174
|
+
divisor: int,
|
|
175
|
+
):
|
|
176
|
+
i = wp.tid()
|
|
177
|
+
node = node_indices[i]
|
|
178
|
+
sort_keys[i] = wp.where(node >= 0, node, NULL_NODE_INDEX)
|
|
179
|
+
sort_values[i] = i // divisor
|
|
159
180
|
|
|
160
181
|
|
|
161
182
|
@wp.kernel
|
warp/jax.py
CHANGED
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2023 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import warp
|
|
9
17
|
|
|
@@ -50,6 +58,19 @@ def device_from_jax(jax_device) -> warp.context.Device:
|
|
|
50
58
|
raise RuntimeError(f"Unsupported Jax device platform '{jax_device.platform}'")
|
|
51
59
|
|
|
52
60
|
|
|
61
|
+
def get_jax_device():
|
|
62
|
+
"""Get the current Jax device."""
|
|
63
|
+
import jax
|
|
64
|
+
|
|
65
|
+
# TODO: is there a simpler way of getting the Jax "current" device?
|
|
66
|
+
# check if jax.default_device() context manager is active
|
|
67
|
+
device = jax.config.jax_default_device
|
|
68
|
+
# if default device is not set, use first device
|
|
69
|
+
if device is None:
|
|
70
|
+
device = jax.local_devices()[0]
|
|
71
|
+
return device
|
|
72
|
+
|
|
73
|
+
|
|
53
74
|
def dtype_to_jax(warp_dtype):
|
|
54
75
|
"""Return the Jax dtype corresponding to a Warp dtype.
|
|
55
76
|
|
|
@@ -148,7 +169,7 @@ def to_jax(warp_array):
|
|
|
148
169
|
"""
|
|
149
170
|
import jax.dlpack
|
|
150
171
|
|
|
151
|
-
return jax.dlpack.from_dlpack(
|
|
172
|
+
return jax.dlpack.from_dlpack(warp_array)
|
|
152
173
|
|
|
153
174
|
|
|
154
175
|
def from_jax(jax_array, dtype=None) -> warp.array:
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from .custom_call import jax_kernel
|
|
@@ -1,16 +1,23 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import ctypes
|
|
9
17
|
|
|
10
|
-
import jax
|
|
11
|
-
|
|
12
18
|
import warp as wp
|
|
13
19
|
from warp.context import type_str
|
|
20
|
+
from warp.jax import get_jax_device
|
|
14
21
|
from warp.types import array_t, launch_bounds_t, strides_from_shape
|
|
15
22
|
|
|
16
23
|
_jax_warp_p = None
|
|
@@ -21,35 +28,33 @@ _registered_kernels = [None]
|
|
|
21
28
|
_registered_kernel_to_id = {}
|
|
22
29
|
|
|
23
30
|
|
|
24
|
-
def jax_kernel(
|
|
31
|
+
def jax_kernel(kernel, launch_dims=None):
|
|
25
32
|
"""Create a Jax primitive from a Warp kernel.
|
|
26
33
|
|
|
27
34
|
NOTE: This is an experimental feature under development.
|
|
28
35
|
|
|
29
36
|
Args:
|
|
30
|
-
|
|
37
|
+
kernel: The Warp kernel to be wrapped.
|
|
31
38
|
launch_dims: Optional. Specify the kernel launch dimensions. If None,
|
|
32
39
|
dimensions are inferred from the shape of the first argument.
|
|
33
40
|
This option when set will specify the output dimensions.
|
|
34
41
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
- All arrays must be contiguous.
|
|
41
|
-
- Only the CUDA backend is supported.
|
|
42
|
+
Limitations:
|
|
43
|
+
- All kernel arguments must be contiguous arrays.
|
|
44
|
+
- Input arguments are followed by output arguments in the Warp kernel definition.
|
|
45
|
+
- There must be at least one input argument and at least one output argument.
|
|
46
|
+
- Only the CUDA backend is supported.
|
|
42
47
|
"""
|
|
43
48
|
|
|
44
49
|
if _jax_warp_p is None:
|
|
45
50
|
# Create and register the primitive
|
|
46
51
|
_create_jax_warp_primitive()
|
|
47
|
-
if
|
|
52
|
+
if kernel not in _registered_kernel_to_id:
|
|
48
53
|
id = len(_registered_kernels)
|
|
49
|
-
_registered_kernels.append(
|
|
50
|
-
_registered_kernel_to_id[
|
|
54
|
+
_registered_kernels.append(kernel)
|
|
55
|
+
_registered_kernel_to_id[kernel] = id
|
|
51
56
|
else:
|
|
52
|
-
id = _registered_kernel_to_id[
|
|
57
|
+
id = _registered_kernel_to_id[kernel]
|
|
53
58
|
|
|
54
59
|
def bind(*args):
|
|
55
60
|
return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims)
|
|
@@ -94,7 +99,7 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
|
|
|
94
99
|
kernel_params[i + 1] = arg_ptr
|
|
95
100
|
|
|
96
101
|
# Get current device.
|
|
97
|
-
device = wp.device_from_jax(
|
|
102
|
+
device = wp.device_from_jax(get_jax_device())
|
|
98
103
|
|
|
99
104
|
# Get kernel hooks.
|
|
100
105
|
# Note: module was loaded during jit lowering.
|
|
@@ -107,16 +112,6 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
|
|
|
107
112
|
)
|
|
108
113
|
|
|
109
114
|
|
|
110
|
-
# TODO: is there a simpler way of getting the Jax "current" device?
|
|
111
|
-
def _get_jax_device():
|
|
112
|
-
# check if jax.default_device() context manager is active
|
|
113
|
-
device = jax.config.jax_default_device
|
|
114
|
-
# if default device is not set, use first device
|
|
115
|
-
if device is None:
|
|
116
|
-
device = jax.local_devices()[0]
|
|
117
|
-
return device
|
|
118
|
-
|
|
119
|
-
|
|
120
115
|
def _create_jax_warp_primitive():
|
|
121
116
|
from functools import reduce
|
|
122
117
|
|
|
@@ -280,7 +275,7 @@ def _create_jax_warp_primitive():
|
|
|
280
275
|
# TODO This may not be necessary, but it is perhaps better not to be
|
|
281
276
|
# mucking with kernel loading while already running the workload.
|
|
282
277
|
module = wp_kernel.module
|
|
283
|
-
device = wp.device_from_jax(
|
|
278
|
+
device = wp.device_from_jax(get_jax_device())
|
|
284
279
|
if not module.load(device):
|
|
285
280
|
raise Exception("Could not load kernel on device")
|
|
286
281
|
|