warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
|
@@ -1,15 +1,24 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
7
16
|
import builtins
|
|
8
17
|
import functools
|
|
9
|
-
import tempfile
|
|
10
|
-
from pathlib import Path
|
|
11
18
|
from typing import Any, Callable, Mapping, Sequence
|
|
12
19
|
|
|
20
|
+
import warp.build
|
|
21
|
+
import warp.context
|
|
13
22
|
from warp.codegen import Reference, Var, strip_reference
|
|
14
23
|
from warp.types import *
|
|
15
24
|
|
|
@@ -32,7 +41,7 @@ def sametypes(arg_types: Mapping[str, Any]):
|
|
|
32
41
|
return all(types_equal(arg_type_0, t) for t in arg_types_iter)
|
|
33
42
|
|
|
34
43
|
|
|
35
|
-
def sametypes_create_value_func(default):
|
|
44
|
+
def sametypes_create_value_func(default: TypeVar):
|
|
36
45
|
def fn(arg_types, arg_values):
|
|
37
46
|
if arg_types is None:
|
|
38
47
|
return default
|
|
@@ -390,7 +399,7 @@ add_builtin(
|
|
|
390
399
|
)
|
|
391
400
|
|
|
392
401
|
|
|
393
|
-
def scalar_infer_type(arg_types: Mapping[str, type]):
|
|
402
|
+
def scalar_infer_type(arg_types: Union[Mapping[str, type], Tuple[type, ...], None]):
|
|
394
403
|
if arg_types is None:
|
|
395
404
|
return Scalar
|
|
396
405
|
|
|
@@ -941,6 +950,12 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
941
950
|
raise RuntimeError("the `shape` argument must be specified when initializing a matrix by value")
|
|
942
951
|
|
|
943
952
|
if all(type_is_vector(x) for x in variadic_arg_types):
|
|
953
|
+
warp.utils.warn(
|
|
954
|
+
"the built-in `wp.matrix()` won't support taking column vectors as input "
|
|
955
|
+
"in the future. Use `wp.matrix_from_rows()` or `wp.matrix_from_cols()` instead.",
|
|
956
|
+
DeprecationWarning,
|
|
957
|
+
)
|
|
958
|
+
|
|
944
959
|
if shape[1] != variadic_arg_count:
|
|
945
960
|
raise RuntimeError(
|
|
946
961
|
f"incompatible number of column vectors given ({variadic_arg_count}) "
|
|
@@ -1021,6 +1036,86 @@ add_builtin(
|
|
|
1021
1036
|
)
|
|
1022
1037
|
|
|
1023
1038
|
|
|
1039
|
+
def matrix_from_vecs_create_value_func(cols: bool):
|
|
1040
|
+
def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1041
|
+
if arg_types is None:
|
|
1042
|
+
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
1043
|
+
|
|
1044
|
+
variadic_arg_types = arg_types.get("args", ())
|
|
1045
|
+
variadic_arg_count = len(variadic_arg_types)
|
|
1046
|
+
|
|
1047
|
+
if not all(type_is_vector(x) for x in variadic_arg_types):
|
|
1048
|
+
raise RuntimeError("all arguments are expected to be vectors")
|
|
1049
|
+
|
|
1050
|
+
length = variadic_arg_types[0]._length_
|
|
1051
|
+
if any(x._length_ != length for x in variadic_arg_types):
|
|
1052
|
+
raise RuntimeError("all vectors are expected to have the same length")
|
|
1053
|
+
|
|
1054
|
+
dtype = variadic_arg_types[0]._wp_scalar_type_
|
|
1055
|
+
if any(x._wp_scalar_type_ != dtype for x in variadic_arg_types):
|
|
1056
|
+
raise RuntimeError("all vectors are expected to have the same dtype")
|
|
1057
|
+
|
|
1058
|
+
shape = (length, variadic_arg_count) if cols else (variadic_arg_count, length)
|
|
1059
|
+
return matrix(shape=shape, dtype=dtype)
|
|
1060
|
+
|
|
1061
|
+
return fn
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
def matrix_from_vecs_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1065
|
+
# We're in the codegen stage where we emit the code calling the built-in.
|
|
1066
|
+
# Further validate the given argument values if needed and map them
|
|
1067
|
+
# to the underlying C++ function's runtime and template params.
|
|
1068
|
+
|
|
1069
|
+
shape = return_type._shape_
|
|
1070
|
+
dtype = return_type._wp_scalar_type_
|
|
1071
|
+
|
|
1072
|
+
variadic_args = args.get("args", ())
|
|
1073
|
+
|
|
1074
|
+
func_args = variadic_args
|
|
1075
|
+
|
|
1076
|
+
if shape in ((2, 2), (3, 3), (4, 4)):
|
|
1077
|
+
# Template specializations exist for these shapes, don't pass them
|
|
1078
|
+
# as template parameters.
|
|
1079
|
+
template_args = (dtype,)
|
|
1080
|
+
else:
|
|
1081
|
+
template_args = (*shape, dtype)
|
|
1082
|
+
|
|
1083
|
+
return (func_args, template_args)
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
def matrix_from_vecs_initializer_list_func(args, return_type):
|
|
1087
|
+
shape = return_type._shape_
|
|
1088
|
+
|
|
1089
|
+
return shape[0] != shape[1] or shape[0] > 4
|
|
1090
|
+
|
|
1091
|
+
|
|
1092
|
+
add_builtin(
|
|
1093
|
+
"matrix_from_cols",
|
|
1094
|
+
input_types={"*args": vector(length=Any, dtype=Scalar)},
|
|
1095
|
+
variadic=True,
|
|
1096
|
+
value_func=matrix_from_vecs_create_value_func(cols=True),
|
|
1097
|
+
dispatch_func=matrix_from_vecs_dispatch_func,
|
|
1098
|
+
initializer_list_func=matrix_from_vecs_initializer_list_func,
|
|
1099
|
+
native_func="matrix_from_cols",
|
|
1100
|
+
doc="Construct a matrix from column vectors.",
|
|
1101
|
+
group="Vector Math",
|
|
1102
|
+
export=False,
|
|
1103
|
+
)
|
|
1104
|
+
|
|
1105
|
+
add_builtin(
|
|
1106
|
+
"matrix_from_rows",
|
|
1107
|
+
input_types={"*args": vector(length=Any, dtype=Scalar)},
|
|
1108
|
+
variadic=True,
|
|
1109
|
+
value_func=matrix_from_vecs_create_value_func(cols=False),
|
|
1110
|
+
dispatch_func=matrix_from_vecs_dispatch_func,
|
|
1111
|
+
initializer_list_func=matrix_from_vecs_initializer_list_func,
|
|
1112
|
+
native_func="matrix_from_rows",
|
|
1113
|
+
doc="Construct a matrix from row vectors.",
|
|
1114
|
+
group="Vector Math",
|
|
1115
|
+
export=False,
|
|
1116
|
+
)
|
|
1117
|
+
|
|
1118
|
+
|
|
1024
1119
|
def identity_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1025
1120
|
if arg_types is None:
|
|
1026
1121
|
return matrix(shape=(Any, Any), dtype=Scalar)
|
|
@@ -1132,6 +1227,21 @@ add_builtin(
|
|
|
1132
1227
|
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1133
1228
|
)
|
|
1134
1229
|
|
|
1230
|
+
add_builtin(
|
|
1231
|
+
"svd2",
|
|
1232
|
+
input_types={
|
|
1233
|
+
"A": matrix(shape=(2, 2), dtype=Float),
|
|
1234
|
+
"U": matrix(shape=(2, 2), dtype=Float),
|
|
1235
|
+
"sigma": vector(length=2, dtype=Float),
|
|
1236
|
+
"V": matrix(shape=(2, 2), dtype=Scalar),
|
|
1237
|
+
},
|
|
1238
|
+
value_type=None,
|
|
1239
|
+
group="Vector Math",
|
|
1240
|
+
export=False,
|
|
1241
|
+
doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
|
|
1242
|
+
while the left and right basis vectors are returned in ``U`` and ``V``.""",
|
|
1243
|
+
)
|
|
1244
|
+
|
|
1135
1245
|
add_builtin(
|
|
1136
1246
|
"qr3",
|
|
1137
1247
|
input_types={
|
|
@@ -1323,7 +1433,18 @@ add_builtin(
|
|
|
1323
1433
|
input_types={"mat": matrix(shape=(3, 3), dtype=Float)},
|
|
1324
1434
|
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1325
1435
|
group="Quaternion Math",
|
|
1326
|
-
doc="Construct a quaternion from a 3x3 matrix.
|
|
1436
|
+
doc="""Construct a quaternion from a 3x3 matrix.
|
|
1437
|
+
|
|
1438
|
+
If the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
|
|
1439
|
+
)
|
|
1440
|
+
add_builtin(
|
|
1441
|
+
"quat_from_matrix",
|
|
1442
|
+
input_types={"mat": matrix(shape=(4, 4), dtype=Float)},
|
|
1443
|
+
value_func=lambda arg_types, arg_values: quaternion(dtype=float_infer_type(arg_types)),
|
|
1444
|
+
group="Quaternion Math",
|
|
1445
|
+
doc="""Construct a quaternion from a 4x4 matrix.
|
|
1446
|
+
|
|
1447
|
+
If the top-left 3x3 block of the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.""",
|
|
1327
1448
|
)
|
|
1328
1449
|
add_builtin(
|
|
1329
1450
|
"quat_rpy",
|
|
@@ -2366,7 +2487,7 @@ add_builtin(
|
|
|
2366
2487
|
|
|
2367
2488
|
This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
|
|
2368
2489
|
|
|
2369
|
-
* If the input value is a scalar, then the resulting tile has ``shape=(
|
|
2490
|
+
* If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
|
|
2370
2491
|
* If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
|
|
2371
2492
|
|
|
2372
2493
|
:param x: A per-thread local value, e.g. scalar, vector, or matrix.
|
|
@@ -2660,11 +2781,9 @@ def tile_broadcast_value_func(arg_types, arg_values):
|
|
|
2660
2781
|
def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2661
2782
|
tile = arg_values["a"]
|
|
2662
2783
|
|
|
2663
|
-
|
|
2664
|
-
|
|
2665
|
-
template_args.
|
|
2666
|
-
template_args.append(return_type.strides[0])
|
|
2667
|
-
template_args.append(return_type.strides[1])
|
|
2784
|
+
assert len(return_type.shape) == len(return_type.strides)
|
|
2785
|
+
assert 1 <= len(return_type.shape) <= 4
|
|
2786
|
+
template_args = [*return_type.shape, *return_type.strides]
|
|
2668
2787
|
|
|
2669
2788
|
return ((tile,), template_args)
|
|
2670
2789
|
|
|
@@ -2677,56 +2796,17 @@ add_builtin(
|
|
|
2677
2796
|
variadic=False,
|
|
2678
2797
|
doc="""Broadcast a tile.
|
|
2679
2798
|
|
|
2680
|
-
|
|
2681
|
-
|
|
2799
|
+
Broadcasts the input tile ``a`` to the destination shape.
|
|
2682
2800
|
Broadcasting follows NumPy broadcast rules.
|
|
2683
2801
|
|
|
2684
2802
|
:param a: Tile to broadcast
|
|
2685
2803
|
:param shape: The shape to broadcast to
|
|
2686
|
-
:returns: Tile with broadcast
|
|
2804
|
+
:returns: Tile with broadcast shape""",
|
|
2687
2805
|
group="Tile Primitives",
|
|
2688
2806
|
export=False,
|
|
2689
2807
|
)
|
|
2690
2808
|
|
|
2691
2809
|
|
|
2692
|
-
def tile_matmul_value_func(arg_types, arg_values):
|
|
2693
|
-
# return generic type (for doc builds)
|
|
2694
|
-
if arg_types is None:
|
|
2695
|
-
return Tile(dtype=Any, shape=Any)
|
|
2696
|
-
|
|
2697
|
-
if len(arg_types) != 3:
|
|
2698
|
-
raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
|
|
2699
|
-
|
|
2700
|
-
return None
|
|
2701
|
-
|
|
2702
|
-
|
|
2703
|
-
def tile_matmul_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2704
|
-
a = arg_values["a"]
|
|
2705
|
-
b = arg_values["b"]
|
|
2706
|
-
out = arg_values["out"]
|
|
2707
|
-
|
|
2708
|
-
# force the storage type of the input variables to shared memory
|
|
2709
|
-
a.type.storage = "shared"
|
|
2710
|
-
b.type.storage = "shared"
|
|
2711
|
-
out.type.storage = "shared"
|
|
2712
|
-
|
|
2713
|
-
template_args = []
|
|
2714
|
-
return ((a, b, out), template_args)
|
|
2715
|
-
|
|
2716
|
-
|
|
2717
|
-
add_builtin(
|
|
2718
|
-
"tile_matmul_scalar",
|
|
2719
|
-
input_types={"a": Tile, "b": Tile, "out": Tile},
|
|
2720
|
-
value_func=tile_matmul_value_func,
|
|
2721
|
-
dispatch_func=tile_matmul_dispatch_func,
|
|
2722
|
-
variadic=True,
|
|
2723
|
-
doc="Compute matrix product and accumulate out += a*b.",
|
|
2724
|
-
group="Tile Primitives",
|
|
2725
|
-
hidden=True,
|
|
2726
|
-
export=False,
|
|
2727
|
-
)
|
|
2728
|
-
|
|
2729
|
-
|
|
2730
2810
|
def tile_sum_value_func(arg_types, arg_values):
|
|
2731
2811
|
# return generic type (for doc builds)
|
|
2732
2812
|
if arg_types is None:
|
|
@@ -3021,7 +3101,7 @@ def tile_binary_map_value_func(arg_types, arg_values):
|
|
|
3021
3101
|
|
|
3022
3102
|
for i in range(len(a.shape)):
|
|
3023
3103
|
if a.shape[i] != b.shape[i]:
|
|
3024
|
-
raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape
|
|
3104
|
+
raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape} and {b.shape}")
|
|
3025
3105
|
|
|
3026
3106
|
return TileBinaryMap(a, b)
|
|
3027
3107
|
|
|
@@ -3798,6 +3878,18 @@ _volume_supported_value_types = {
|
|
|
3798
3878
|
}
|
|
3799
3879
|
|
|
3800
3880
|
|
|
3881
|
+
def _is_volume_type_supported(dtype):
|
|
3882
|
+
for typ in _volume_supported_value_types:
|
|
3883
|
+
if types_equal(typ, dtype):
|
|
3884
|
+
return True
|
|
3885
|
+
return False
|
|
3886
|
+
|
|
3887
|
+
|
|
3888
|
+
def _check_volume_type_is_supported(dtype):
|
|
3889
|
+
if not _is_volume_type_supported(dtype):
|
|
3890
|
+
raise RuntimeError(f"unsupported volume type `{type_repr(dtype)}`")
|
|
3891
|
+
|
|
3892
|
+
|
|
3801
3893
|
def check_volume_value_grad_compatibility(dtype, grad_dtype):
|
|
3802
3894
|
if type_is_vector(dtype):
|
|
3803
3895
|
expected = matrix(shape=(type_length(dtype), 3), dtype=type_scalar_type(dtype))
|
|
@@ -3813,9 +3905,7 @@ def volume_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
|
|
|
3813
3905
|
return Any
|
|
3814
3906
|
|
|
3815
3907
|
dtype = arg_values["dtype"]
|
|
3816
|
-
|
|
3817
|
-
if dtype not in _volume_supported_value_types:
|
|
3818
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3908
|
+
_check_volume_type_is_supported(dtype)
|
|
3819
3909
|
|
|
3820
3910
|
return dtype
|
|
3821
3911
|
|
|
@@ -3851,9 +3941,7 @@ def volume_sample_grad_value_func(arg_types: Mapping[str, type], arg_values: Map
|
|
|
3851
3941
|
return Any
|
|
3852
3942
|
|
|
3853
3943
|
dtype = arg_values["dtype"]
|
|
3854
|
-
|
|
3855
|
-
if dtype not in _volume_supported_value_types:
|
|
3856
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3944
|
+
_check_volume_type_is_supported(dtype)
|
|
3857
3945
|
|
|
3858
3946
|
check_volume_value_grad_compatibility(dtype, arg_types["grad"])
|
|
3859
3947
|
|
|
@@ -3891,9 +3979,7 @@ def volume_lookup_value_func(arg_types: Mapping[str, type], arg_values: Mapping[
|
|
|
3891
3979
|
return Any
|
|
3892
3980
|
|
|
3893
3981
|
dtype = arg_values["dtype"]
|
|
3894
|
-
|
|
3895
|
-
if dtype not in _volume_supported_value_types:
|
|
3896
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
3982
|
+
_check_volume_type_is_supported(dtype)
|
|
3897
3983
|
|
|
3898
3984
|
return dtype
|
|
3899
3985
|
|
|
@@ -3930,9 +4016,7 @@ def volume_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[s
|
|
|
3930
4016
|
return None
|
|
3931
4017
|
|
|
3932
4018
|
dtype = arg_types["value"]
|
|
3933
|
-
|
|
3934
|
-
if dtype not in _volume_supported_value_types:
|
|
3935
|
-
raise RuntimeError(f"unsupported volume type `{dtype.__name__}`")
|
|
4019
|
+
_check_volume_type_is_supported(dtype)
|
|
3936
4020
|
|
|
3937
4021
|
return None
|
|
3938
4022
|
|
|
@@ -4182,6 +4266,20 @@ add_builtin(
|
|
|
4182
4266
|
group="Random",
|
|
4183
4267
|
doc="Return a random integer between [low, high).",
|
|
4184
4268
|
)
|
|
4269
|
+
add_builtin(
|
|
4270
|
+
"randu",
|
|
4271
|
+
input_types={"state": uint32},
|
|
4272
|
+
value_type=uint32,
|
|
4273
|
+
group="Random",
|
|
4274
|
+
doc="Return a random unsigned integer in the range [0, 2^32).",
|
|
4275
|
+
)
|
|
4276
|
+
add_builtin(
|
|
4277
|
+
"randu",
|
|
4278
|
+
input_types={"state": uint32, "low": uint32, "high": uint32},
|
|
4279
|
+
value_type=uint32,
|
|
4280
|
+
group="Random",
|
|
4281
|
+
doc="Return a random unsigned integer between [low, high).",
|
|
4282
|
+
)
|
|
4185
4283
|
add_builtin(
|
|
4186
4284
|
"randf",
|
|
4187
4285
|
input_types={"state": uint32},
|
|
@@ -4490,11 +4588,31 @@ add_builtin(
|
|
|
4490
4588
|
export=False,
|
|
4491
4589
|
group="Utility",
|
|
4492
4590
|
)
|
|
4591
|
+
|
|
4592
|
+
|
|
4593
|
+
def select_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
4594
|
+
warp.utils.warn(
|
|
4595
|
+
"wp.select() is deprecated and will be removed in a future\n"
|
|
4596
|
+
"version. Use wp.where(cond, value_if_true, value_if_false) instead.",
|
|
4597
|
+
category=DeprecationWarning,
|
|
4598
|
+
)
|
|
4599
|
+
|
|
4600
|
+
func_args = tuple(args.values())
|
|
4601
|
+
template_args = ()
|
|
4602
|
+
|
|
4603
|
+
return (func_args, template_args)
|
|
4604
|
+
|
|
4605
|
+
|
|
4493
4606
|
add_builtin(
|
|
4494
4607
|
"select",
|
|
4495
4608
|
input_types={"cond": builtins.bool, "value_if_false": Any, "value_if_true": Any},
|
|
4496
4609
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4497
|
-
|
|
4610
|
+
dispatch_func=select_dispatch_func,
|
|
4611
|
+
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4612
|
+
|
|
4613
|
+
.. deprecated:: 1.7
|
|
4614
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4615
|
+
``where(cond, value_if_true, value_if_false)``.""",
|
|
4498
4616
|
group="Utility",
|
|
4499
4617
|
)
|
|
4500
4618
|
for t in int_types:
|
|
@@ -4502,14 +4620,47 @@ for t in int_types:
|
|
|
4502
4620
|
"select",
|
|
4503
4621
|
input_types={"cond": t, "value_if_false": Any, "value_if_true": Any},
|
|
4504
4622
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4505
|
-
|
|
4623
|
+
dispatch_func=select_dispatch_func,
|
|
4624
|
+
doc="""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4625
|
+
|
|
4626
|
+
.. deprecated:: 1.7
|
|
4627
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4628
|
+
``where(cond, value_if_true, value_if_false)``.""",
|
|
4506
4629
|
group="Utility",
|
|
4507
4630
|
)
|
|
4508
4631
|
add_builtin(
|
|
4509
4632
|
"select",
|
|
4510
4633
|
input_types={"arr": array(dtype=Any), "value_if_false": Any, "value_if_true": Any},
|
|
4511
4634
|
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4512
|
-
|
|
4635
|
+
dispatch_func=select_dispatch_func,
|
|
4636
|
+
doc="""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
4637
|
+
|
|
4638
|
+
.. deprecated:: 1.7
|
|
4639
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
4640
|
+
``where(arr, value_if_true, value_if_false)``.""",
|
|
4641
|
+
group="Utility",
|
|
4642
|
+
)
|
|
4643
|
+
|
|
4644
|
+
add_builtin(
|
|
4645
|
+
"where",
|
|
4646
|
+
input_types={"cond": builtins.bool, "value_if_true": Any, "value_if_false": Any},
|
|
4647
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4648
|
+
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4649
|
+
group="Utility",
|
|
4650
|
+
)
|
|
4651
|
+
for t in int_types:
|
|
4652
|
+
add_builtin(
|
|
4653
|
+
"where",
|
|
4654
|
+
input_types={"cond": t, "value_if_true": Any, "value_if_false": Any},
|
|
4655
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4656
|
+
doc="Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4657
|
+
group="Utility",
|
|
4658
|
+
)
|
|
4659
|
+
add_builtin(
|
|
4660
|
+
"where",
|
|
4661
|
+
input_types={"arr": array(dtype=Any), "value_if_true": Any, "value_if_false": Any},
|
|
4662
|
+
value_func=lambda arg_types, arg_values: Any if arg_types is None else arg_types["value_if_false"],
|
|
4663
|
+
doc="Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``.",
|
|
4513
4664
|
group="Utility",
|
|
4514
4665
|
)
|
|
4515
4666
|
|
|
@@ -5103,33 +5254,51 @@ add_builtin(
|
|
|
5103
5254
|
)
|
|
5104
5255
|
|
|
5105
5256
|
|
|
5257
|
+
# implements vector[index] = value
|
|
5258
|
+
add_builtin(
|
|
5259
|
+
"assign_inplace",
|
|
5260
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5261
|
+
value_type=None,
|
|
5262
|
+
hidden=True,
|
|
5263
|
+
group="Utility",
|
|
5264
|
+
)
|
|
5265
|
+
|
|
5266
|
+
# implements quaternion[index] = value
|
|
5267
|
+
add_builtin(
|
|
5268
|
+
"assign_inplace",
|
|
5269
|
+
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5270
|
+
value_type=None,
|
|
5271
|
+
hidden=True,
|
|
5272
|
+
group="Utility",
|
|
5273
|
+
)
|
|
5274
|
+
|
|
5275
|
+
|
|
5106
5276
|
def vector_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5107
5277
|
vec_type = arg_types["a"]
|
|
5108
5278
|
return vec_type
|
|
5109
5279
|
|
|
5110
5280
|
|
|
5111
|
-
# implements vector[index] = value
|
|
5281
|
+
# implements vector[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
5112
5282
|
add_builtin(
|
|
5113
|
-
"
|
|
5283
|
+
"assign_copy",
|
|
5114
5284
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5115
5285
|
value_func=vector_assign_value_func,
|
|
5116
5286
|
hidden=True,
|
|
5117
5287
|
group="Utility",
|
|
5118
5288
|
)
|
|
5119
5289
|
|
|
5120
|
-
# implements quaternion[index] = value
|
|
5290
|
+
# implements quaternion[index] = value, performs a copy internally if wp.config.enable_vector_component_overwrites is True
|
|
5121
5291
|
add_builtin(
|
|
5122
|
-
"
|
|
5292
|
+
"assign_copy",
|
|
5123
5293
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5124
5294
|
value_func=vector_assign_value_func,
|
|
5125
5295
|
hidden=True,
|
|
5126
5296
|
group="Utility",
|
|
5127
5297
|
)
|
|
5128
5298
|
|
|
5129
|
-
|
|
5130
5299
|
# implements vector[idx] += scalar
|
|
5131
5300
|
add_builtin(
|
|
5132
|
-
"
|
|
5301
|
+
"add_inplace",
|
|
5133
5302
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5134
5303
|
value_type=None,
|
|
5135
5304
|
hidden=True,
|
|
@@ -5138,7 +5307,7 @@ add_builtin(
|
|
|
5138
5307
|
|
|
5139
5308
|
# implements quaternion[idx] += scalar
|
|
5140
5309
|
add_builtin(
|
|
5141
|
-
"
|
|
5310
|
+
"add_inplace",
|
|
5142
5311
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5143
5312
|
value_type=None,
|
|
5144
5313
|
hidden=True,
|
|
@@ -5147,7 +5316,7 @@ add_builtin(
|
|
|
5147
5316
|
|
|
5148
5317
|
# implements vector[idx] -= scalar
|
|
5149
5318
|
add_builtin(
|
|
5150
|
-
"
|
|
5319
|
+
"sub_inplace",
|
|
5151
5320
|
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5152
5321
|
value_type=None,
|
|
5153
5322
|
hidden=True,
|
|
@@ -5156,7 +5325,7 @@ add_builtin(
|
|
|
5156
5325
|
|
|
5157
5326
|
# implements quaternion[idx] -= scalar
|
|
5158
5327
|
add_builtin(
|
|
5159
|
-
"
|
|
5328
|
+
"sub_inplace",
|
|
5160
5329
|
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5161
5330
|
value_type=None,
|
|
5162
5331
|
hidden=True,
|
|
@@ -5200,11 +5369,6 @@ add_builtin(
|
|
|
5200
5369
|
)
|
|
5201
5370
|
|
|
5202
5371
|
|
|
5203
|
-
def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5204
|
-
mat_type = arg_types["a"]
|
|
5205
|
-
return mat_type
|
|
5206
|
-
|
|
5207
|
-
|
|
5208
5372
|
def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
5209
5373
|
mat_size = arg_types["a"]._shape_[0]
|
|
5210
5374
|
vec_size = arg_types["value"]._length_
|
|
@@ -5215,7 +5379,33 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
|
5215
5379
|
|
|
5216
5380
|
# implements matrix[i,j] = scalar
|
|
5217
5381
|
add_builtin(
|
|
5218
|
-
"
|
|
5382
|
+
"assign_inplace",
|
|
5383
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5384
|
+
value_type=None,
|
|
5385
|
+
hidden=True,
|
|
5386
|
+
group="Utility",
|
|
5387
|
+
)
|
|
5388
|
+
|
|
5389
|
+
|
|
5390
|
+
# implements matrix[i] = vector
|
|
5391
|
+
add_builtin(
|
|
5392
|
+
"assign_inplace",
|
|
5393
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5394
|
+
constraint=matrix_vector_sametype,
|
|
5395
|
+
value_type=None,
|
|
5396
|
+
hidden=True,
|
|
5397
|
+
group="Utility",
|
|
5398
|
+
)
|
|
5399
|
+
|
|
5400
|
+
|
|
5401
|
+
def matrix_assign_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5402
|
+
mat_type = arg_types["a"]
|
|
5403
|
+
return mat_type
|
|
5404
|
+
|
|
5405
|
+
|
|
5406
|
+
# implements matrix[i,j] = scalar
|
|
5407
|
+
add_builtin(
|
|
5408
|
+
"assign_copy",
|
|
5219
5409
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5220
5410
|
value_func=matrix_assign_value_func,
|
|
5221
5411
|
hidden=True,
|
|
@@ -5225,7 +5415,7 @@ add_builtin(
|
|
|
5225
5415
|
|
|
5226
5416
|
# implements matrix[i] = vector
|
|
5227
5417
|
add_builtin(
|
|
5228
|
-
"
|
|
5418
|
+
"assign_copy",
|
|
5229
5419
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5230
5420
|
constraint=matrix_vector_sametype,
|
|
5231
5421
|
value_func=matrix_assign_value_func,
|
|
@@ -5236,7 +5426,7 @@ add_builtin(
|
|
|
5236
5426
|
|
|
5237
5427
|
# implements matrix[i,j] += scalar
|
|
5238
5428
|
add_builtin(
|
|
5239
|
-
"
|
|
5429
|
+
"add_inplace",
|
|
5240
5430
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5241
5431
|
value_type=None,
|
|
5242
5432
|
hidden=True,
|
|
@@ -5244,9 +5434,20 @@ add_builtin(
|
|
|
5244
5434
|
)
|
|
5245
5435
|
|
|
5246
5436
|
|
|
5437
|
+
# implements matrix[i] += vector
|
|
5438
|
+
add_builtin(
|
|
5439
|
+
"add_inplace",
|
|
5440
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5441
|
+
constraint=matrix_vector_sametype,
|
|
5442
|
+
value_type=None,
|
|
5443
|
+
hidden=True,
|
|
5444
|
+
group="Utility",
|
|
5445
|
+
)
|
|
5446
|
+
|
|
5447
|
+
|
|
5247
5448
|
# implements matrix[i,j] -= scalar
|
|
5248
5449
|
add_builtin(
|
|
5249
|
-
"
|
|
5450
|
+
"sub_inplace",
|
|
5250
5451
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5251
5452
|
value_type=None,
|
|
5252
5453
|
hidden=True,
|
|
@@ -5254,6 +5455,16 @@ add_builtin(
|
|
|
5254
5455
|
)
|
|
5255
5456
|
|
|
5256
5457
|
|
|
5458
|
+
# implements matrix[i] -= vector
|
|
5459
|
+
add_builtin(
|
|
5460
|
+
"sub_inplace",
|
|
5461
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5462
|
+
value_type=None,
|
|
5463
|
+
hidden=True,
|
|
5464
|
+
group="Utility",
|
|
5465
|
+
)
|
|
5466
|
+
|
|
5467
|
+
|
|
5257
5468
|
for t in scalar_types + vector_types + (bool,):
|
|
5258
5469
|
if "vec" in t.__name__ or "mat" in t.__name__:
|
|
5259
5470
|
continue
|
|
@@ -5401,7 +5612,27 @@ add_builtin(
|
|
|
5401
5612
|
)
|
|
5402
5613
|
add_builtin(
|
|
5403
5614
|
"expect_near",
|
|
5404
|
-
input_types={"a":
|
|
5615
|
+
input_types={"a": vector(length=Any, dtype=Float), "b": vector(length=Any, dtype=Float), "tolerance": Float},
|
|
5616
|
+
defaults={"tolerance": 1.0e-6},
|
|
5617
|
+
value_type=None,
|
|
5618
|
+
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5619
|
+
group="Utility",
|
|
5620
|
+
)
|
|
5621
|
+
add_builtin(
|
|
5622
|
+
"expect_near",
|
|
5623
|
+
input_types={"a": quaternion(dtype=Float), "b": quaternion(dtype=Float), "tolerance": Float},
|
|
5624
|
+
defaults={"tolerance": 1.0e-6},
|
|
5625
|
+
value_type=None,
|
|
5626
|
+
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
5627
|
+
group="Utility",
|
|
5628
|
+
)
|
|
5629
|
+
add_builtin(
|
|
5630
|
+
"expect_near",
|
|
5631
|
+
input_types={
|
|
5632
|
+
"a": matrix(shape=(Any, Any), dtype=Float),
|
|
5633
|
+
"b": matrix(shape=(Any, Any), dtype=Float),
|
|
5634
|
+
"tolerance": Float,
|
|
5635
|
+
},
|
|
5405
5636
|
defaults={"tolerance": 1.0e-6},
|
|
5406
5637
|
value_type=None,
|
|
5407
5638
|
doc="Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude",
|
|
@@ -5980,7 +6211,7 @@ add_builtin(
|
|
|
5980
6211
|
##
|
|
5981
6212
|
## Matmul
|
|
5982
6213
|
##
|
|
5983
|
-
def
|
|
6214
|
+
def tile_matmul_value_func(arg_types, arg_values):
|
|
5984
6215
|
# return generic type (for doc builds)
|
|
5985
6216
|
if arg_types is None:
|
|
5986
6217
|
return Tile(dtype=Any, shape=Any)
|
|
@@ -6006,7 +6237,7 @@ def tile_matmul_generic_value_func(arg_types, arg_values):
|
|
|
6006
6237
|
return None
|
|
6007
6238
|
|
|
6008
6239
|
|
|
6009
|
-
def
|
|
6240
|
+
def tile_matmul_lto_dispatch_func(
|
|
6010
6241
|
arg_types: Mapping[str, type],
|
|
6011
6242
|
return_type: Any,
|
|
6012
6243
|
return_values: List[Var],
|
|
@@ -6045,142 +6276,82 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
6045
6276
|
out.type.storage = "shared"
|
|
6046
6277
|
template_args = [accumulate]
|
|
6047
6278
|
|
|
6048
|
-
# Maps Python/Warp types to C++ types and enums
|
|
6049
|
-
def cublasdx_type_map(dtype):
|
|
6050
|
-
if dtype == float16:
|
|
6051
|
-
return ("wp::float16", 3, 0)
|
|
6052
|
-
if dtype == float32:
|
|
6053
|
-
return ("wp::float32", 5, 0)
|
|
6054
|
-
if dtype == float64:
|
|
6055
|
-
return ("wp::float64", 6, 0)
|
|
6056
|
-
if dtype == vec2h:
|
|
6057
|
-
return ("wp::vec2h", 3, 1)
|
|
6058
|
-
if dtype == vec2f:
|
|
6059
|
-
return ("wp::vec2f", 5, 1)
|
|
6060
|
-
if dtype == vec2d:
|
|
6061
|
-
return ("wp::vec2d", 6, 1)
|
|
6062
|
-
raise TypeError("Unsupported input type in tile_matmul")
|
|
6063
|
-
|
|
6064
|
-
def cublasdx_arrangement_map(layout):
|
|
6065
|
-
if layout == "colmajor":
|
|
6066
|
-
return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
|
|
6067
|
-
if layout == "rowmajor":
|
|
6068
|
-
return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
|
|
6069
|
-
raise ValueError("Unsupported layout in tile_matmul")
|
|
6070
|
-
|
|
6071
|
-
# generate the LTO
|
|
6072
6279
|
M, K = a.type.shape[0], a.type.shape[1]
|
|
6073
6280
|
_, N = b.type.shape[0], b.type.shape[1]
|
|
6074
6281
|
num_threads = options["block_dim"]
|
|
6075
6282
|
arch = options["output_arch"]
|
|
6076
6283
|
|
|
6077
|
-
|
|
6078
|
-
|
|
6079
|
-
(
|
|
6080
|
-
|
|
6081
|
-
a_arrangement = cublasdx_arrangement_map(alayout)
|
|
6082
|
-
b_arrangement = cublasdx_arrangement_map(blayout)
|
|
6083
|
-
c_arrangement = cublasdx_arrangement_map(clayout)
|
|
6084
|
-
|
|
6085
|
-
if a_type != b_type or a_type != c_type:
|
|
6086
|
-
raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
|
|
6087
|
-
|
|
6088
|
-
element_type = a_type
|
|
6089
|
-
|
|
6090
|
-
lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
|
|
6284
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6285
|
+
# CPU/no-MathDx dispatch
|
|
6286
|
+
return ((0, 0, 0, a, b, out), template_args, [], 0)
|
|
6287
|
+
else:
|
|
6091
6288
|
|
|
6092
|
-
|
|
6093
|
-
|
|
6094
|
-
|
|
6289
|
+
def tile_flip_layout(layout):
|
|
6290
|
+
if layout == "rowmajor":
|
|
6291
|
+
return "colmajor"
|
|
6292
|
+
elif layout == "colmajor":
|
|
6293
|
+
return "rowmajor"
|
|
6095
6294
|
|
|
6096
|
-
#
|
|
6097
|
-
|
|
6098
|
-
|
|
6099
|
-
|
|
6100
|
-
|
|
6101
|
-
|
|
6102
|
-
|
|
6103
|
-
|
|
6295
|
+
# generate the LTOs
|
|
6296
|
+
# C += A * B
|
|
6297
|
+
(fun_forward, lto_forward) = warp.build.build_lto_dot(
|
|
6298
|
+
M,
|
|
6299
|
+
N,
|
|
6300
|
+
K,
|
|
6301
|
+
a.type.dtype,
|
|
6302
|
+
b.type.dtype,
|
|
6303
|
+
out.type.dtype,
|
|
6304
|
+
a.type.layout,
|
|
6305
|
+
b.type.layout,
|
|
6306
|
+
out.type.layout,
|
|
6104
6307
|
arch,
|
|
6308
|
+
num_threads,
|
|
6309
|
+
builder,
|
|
6310
|
+
)
|
|
6311
|
+
# adjA += adjC * B^T - Transpose ~= flipped layout
|
|
6312
|
+
(fun_backward_A, lto_backward_A) = warp.build.build_lto_dot(
|
|
6105
6313
|
M,
|
|
6314
|
+
K,
|
|
6106
6315
|
N,
|
|
6316
|
+
out.type.dtype,
|
|
6317
|
+
b.type.dtype,
|
|
6318
|
+
a.type.dtype,
|
|
6319
|
+
out.type.layout,
|
|
6320
|
+
tile_flip_layout(b.type.layout),
|
|
6321
|
+
a.type.layout,
|
|
6322
|
+
arch,
|
|
6323
|
+
num_threads,
|
|
6324
|
+
builder,
|
|
6325
|
+
)
|
|
6326
|
+
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
6327
|
+
(fun_backward_B, lto_backward_B) = warp.build.build_lto_dot(
|
|
6107
6328
|
K,
|
|
6108
|
-
|
|
6109
|
-
|
|
6110
|
-
|
|
6111
|
-
|
|
6112
|
-
|
|
6113
|
-
|
|
6114
|
-
|
|
6329
|
+
N,
|
|
6330
|
+
M,
|
|
6331
|
+
a.type.dtype,
|
|
6332
|
+
out.type.dtype,
|
|
6333
|
+
b.type.dtype,
|
|
6334
|
+
tile_flip_layout(a.type.layout),
|
|
6335
|
+
out.type.layout,
|
|
6336
|
+
b.type.layout,
|
|
6337
|
+
arch,
|
|
6115
6338
|
num_threads,
|
|
6339
|
+
builder,
|
|
6116
6340
|
)
|
|
6117
|
-
lto_code_path = Path(lto_code.name)
|
|
6118
|
-
if not result:
|
|
6119
|
-
lto_code.close()
|
|
6120
|
-
if lto_code_path.exists():
|
|
6121
|
-
lto_code_path.unlink()
|
|
6122
|
-
raise RuntimeError("Failed to compile tile_matmul")
|
|
6123
|
-
else:
|
|
6124
|
-
with open(lto_code.name, "rb") as f:
|
|
6125
|
-
lto_code_data = f.read()
|
|
6126
|
-
lto_code.close()
|
|
6127
|
-
lto_code_path.unlink()
|
|
6128
|
-
|
|
6129
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6130
|
-
builder.ltoirs_decl[lto_symbol] = (
|
|
6131
|
-
f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
|
|
6132
|
-
)
|
|
6133
|
-
|
|
6134
|
-
return lto_symbol, lto_code_data
|
|
6135
6341
|
|
|
6136
|
-
|
|
6137
|
-
|
|
6138
|
-
|
|
6139
|
-
|
|
6140
|
-
|
|
6141
|
-
|
|
6142
|
-
|
|
6143
|
-
|
|
6144
|
-
|
|
6145
|
-
|
|
6146
|
-
|
|
6147
|
-
|
|
6148
|
-
|
|
6149
|
-
K,
|
|
6150
|
-
N,
|
|
6151
|
-
out.type.dtype,
|
|
6152
|
-
b.type.dtype,
|
|
6153
|
-
a.type.dtype,
|
|
6154
|
-
out.type.layout,
|
|
6155
|
-
tile_flip_layout(b.type.layout),
|
|
6156
|
-
a.type.layout,
|
|
6157
|
-
)
|
|
6158
|
-
# adjB += A^T * adjC - Transpose ~= flipped layout
|
|
6159
|
-
(fun_backward_B, lto_backward_B) = make_function(
|
|
6160
|
-
K,
|
|
6161
|
-
N,
|
|
6162
|
-
M,
|
|
6163
|
-
a.type.dtype,
|
|
6164
|
-
out.type.dtype,
|
|
6165
|
-
b.type.dtype,
|
|
6166
|
-
tile_flip_layout(a.type.layout),
|
|
6167
|
-
out.type.layout,
|
|
6168
|
-
b.type.layout,
|
|
6169
|
-
)
|
|
6170
|
-
|
|
6171
|
-
return (
|
|
6172
|
-
(
|
|
6173
|
-
Var(fun_forward, str, False, True, False),
|
|
6174
|
-
Var(fun_backward_A, str, False, True, False),
|
|
6175
|
-
Var(fun_backward_B, str, False, True, False),
|
|
6176
|
-
a,
|
|
6177
|
-
b,
|
|
6178
|
-
out,
|
|
6179
|
-
),
|
|
6180
|
-
template_args,
|
|
6181
|
-
[lto_forward, lto_backward_A, lto_backward_B],
|
|
6182
|
-
0,
|
|
6183
|
-
)
|
|
6342
|
+
return (
|
|
6343
|
+
(
|
|
6344
|
+
Var(fun_forward, str, False, True, False),
|
|
6345
|
+
Var(fun_backward_A, str, False, True, False),
|
|
6346
|
+
Var(fun_backward_B, str, False, True, False),
|
|
6347
|
+
a,
|
|
6348
|
+
b,
|
|
6349
|
+
out,
|
|
6350
|
+
),
|
|
6351
|
+
template_args,
|
|
6352
|
+
[lto_forward, lto_backward_A, lto_backward_B],
|
|
6353
|
+
0,
|
|
6354
|
+
)
|
|
6184
6355
|
|
|
6185
6356
|
|
|
6186
6357
|
add_builtin(
|
|
@@ -6190,8 +6361,8 @@ add_builtin(
|
|
|
6190
6361
|
"b": Tile(dtype=Any, shape=Any),
|
|
6191
6362
|
"out": Tile(dtype=Any, shape=Any),
|
|
6192
6363
|
},
|
|
6193
|
-
value_func=
|
|
6194
|
-
lto_dispatch_func=
|
|
6364
|
+
value_func=tile_matmul_value_func,
|
|
6365
|
+
lto_dispatch_func=tile_matmul_lto_dispatch_func,
|
|
6195
6366
|
variadic=False,
|
|
6196
6367
|
doc="""Computes the matrix product and accumulates ``out += a*b``.
|
|
6197
6368
|
|
|
@@ -6199,7 +6370,7 @@ add_builtin(
|
|
|
6199
6370
|
* fp16, fp32, fp64 (real)
|
|
6200
6371
|
* vec2h, vec2f, vec2d (complex)
|
|
6201
6372
|
|
|
6202
|
-
All input and output tiles must have the same datatype. Tile data will
|
|
6373
|
+
All input and output tiles must have the same datatype. Tile data will automatically be migrated
|
|
6203
6374
|
to shared memory if necessary and will use TensorCore operations when available.
|
|
6204
6375
|
|
|
6205
6376
|
:param a: A tile with ``shape=(M, K)``
|
|
@@ -6213,8 +6384,8 @@ add_builtin(
|
|
|
6213
6384
|
add_builtin(
|
|
6214
6385
|
"tile_matmul",
|
|
6215
6386
|
input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
6216
|
-
value_func=
|
|
6217
|
-
lto_dispatch_func=
|
|
6387
|
+
value_func=tile_matmul_value_func,
|
|
6388
|
+
lto_dispatch_func=tile_matmul_lto_dispatch_func,
|
|
6218
6389
|
variadic=False,
|
|
6219
6390
|
doc="""Computes the matrix product ``out = a*b``.
|
|
6220
6391
|
|
|
@@ -6222,7 +6393,7 @@ add_builtin(
|
|
|
6222
6393
|
* fp16, fp32, fp64 (real)
|
|
6223
6394
|
* vec2h, vec2f, vec2d (complex)
|
|
6224
6395
|
|
|
6225
|
-
Both input tiles must have the same datatype. Tile data will
|
|
6396
|
+
Both input tiles must have the same datatype. Tile data will automatically be migrated
|
|
6226
6397
|
to shared memory if necessary and will use TensorCore operations when available.
|
|
6227
6398
|
|
|
6228
6399
|
:param a: A tile with ``shape=(M, K)``
|
|
@@ -6294,59 +6465,29 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6294
6465
|
num_threads = options["block_dim"]
|
|
6295
6466
|
arch = options["output_arch"]
|
|
6296
6467
|
ept = size // num_threads
|
|
6297
|
-
|
|
6298
|
-
|
|
6299
|
-
|
|
6300
|
-
|
|
6301
|
-
|
|
6302
|
-
|
|
6303
|
-
|
|
6304
|
-
|
|
6305
|
-
|
|
6306
|
-
|
|
6307
|
-
|
|
6308
|
-
|
|
6309
|
-
|
|
6310
|
-
|
|
6311
|
-
|
|
6312
|
-
|
|
6313
|
-
|
|
6314
|
-
|
|
6315
|
-
|
|
6316
|
-
|
|
6317
|
-
|
|
6318
|
-
|
|
6319
|
-
|
|
6320
|
-
lto_code_path = Path(lto_code.name)
|
|
6321
|
-
if not result:
|
|
6322
|
-
lto_code.close()
|
|
6323
|
-
if lto_code_path.exists():
|
|
6324
|
-
lto_code_path.unlink()
|
|
6325
|
-
raise RuntimeError("Failed to compile tile_fft")
|
|
6326
|
-
|
|
6327
|
-
with open(lto_code.name, "rb") as f:
|
|
6328
|
-
lto_code_data = f.read()
|
|
6329
|
-
|
|
6330
|
-
lto_code.close()
|
|
6331
|
-
lto_code_path.unlink()
|
|
6332
|
-
|
|
6333
|
-
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6334
|
-
|
|
6335
|
-
shared_memory_bytes = Tile.round_up(shared_memory_size.value)
|
|
6336
|
-
|
|
6337
|
-
return (
|
|
6338
|
-
(
|
|
6339
|
-
Var(lto_symbol, str, False, True, False),
|
|
6340
|
-
Var(dtype, str, False, True, False),
|
|
6341
|
-
Var(str(shared_memory_bytes), str, False, True, False),
|
|
6342
|
-
Var(str(batch), str, False, True, False),
|
|
6343
|
-
Var(str(ept), str, False, True, False),
|
|
6344
|
-
inout,
|
|
6345
|
-
),
|
|
6346
|
-
[],
|
|
6347
|
-
[lto_code_data],
|
|
6348
|
-
shared_memory_bytes,
|
|
6349
|
-
)
|
|
6468
|
+
|
|
6469
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6470
|
+
# CPU/no-MathDx dispatch
|
|
6471
|
+
return ([], [], [], 0)
|
|
6472
|
+
else:
|
|
6473
|
+
# generate the LTO
|
|
6474
|
+
lto_symbol, lto_code_data, shared_memory_bytes = warp.build.build_lto_fft(
|
|
6475
|
+
arch, size, ept, direction, dir, precision, builder
|
|
6476
|
+
)
|
|
6477
|
+
|
|
6478
|
+
return (
|
|
6479
|
+
(
|
|
6480
|
+
Var(lto_symbol, str, False, True, False),
|
|
6481
|
+
Var(dtype, str, False, True, False),
|
|
6482
|
+
Var(str(shared_memory_bytes), str, False, True, False),
|
|
6483
|
+
Var(str(batch), str, False, True, False),
|
|
6484
|
+
Var(str(ept), str, False, True, False),
|
|
6485
|
+
inout,
|
|
6486
|
+
),
|
|
6487
|
+
[],
|
|
6488
|
+
[lto_code_data],
|
|
6489
|
+
shared_memory_bytes,
|
|
6490
|
+
)
|
|
6350
6491
|
|
|
6351
6492
|
|
|
6352
6493
|
add_builtin(
|
|
@@ -6408,7 +6549,7 @@ def tile_cholesky_generic_value_func(arg_types, arg_values):
|
|
|
6408
6549
|
raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")
|
|
6409
6550
|
|
|
6410
6551
|
if len(a.shape) != 2:
|
|
6411
|
-
raise ValueError("tile_cholesky()
|
|
6552
|
+
raise ValueError("tile_cholesky() argument must be a 2D tile")
|
|
6412
6553
|
|
|
6413
6554
|
if a.shape[0] != a.shape[1]:
|
|
6414
6555
|
raise ValueError("tile_cholesky() argument must be square")
|
|
@@ -6449,57 +6590,36 @@ def tile_cholesky_generic_lto_dispatch_func(
|
|
|
6449
6590
|
if out.type.shape[0] != M or out.type.shape[1] != M:
|
|
6450
6591
|
raise ValueError("tile_cholesky() output tile must be square")
|
|
6451
6592
|
|
|
6452
|
-
|
|
6453
|
-
|
|
6454
|
-
lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"
|
|
6455
|
-
|
|
6456
|
-
# early out if LTO for this combination already exists for this module
|
|
6457
|
-
if lto_symbol in builder.ltoirs:
|
|
6458
|
-
return lto_symbol, builder.ltoirs[lto_symbol]
|
|
6459
|
-
|
|
6460
|
-
# otherwise compile LTO
|
|
6461
|
-
lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6462
|
-
universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6593
|
+
solver = "potrf"
|
|
6594
|
+
solver_enum = cusolver_function_map[solver]
|
|
6463
6595
|
|
|
6464
|
-
# cuSOLVERDx only
|
|
6596
|
+
# cuSOLVERDx only supports col-major input/outputs,
|
|
6465
6597
|
# so we use upper to mimic a row-major input
|
|
6466
|
-
|
|
6467
|
-
universal_fatbin_code.name.encode("utf-8"),
|
|
6468
|
-
lto_code.name.encode("utf-8"),
|
|
6469
|
-
lto_symbol.encode("utf-8"),
|
|
6470
|
-
0,
|
|
6471
|
-
None,
|
|
6472
|
-
None,
|
|
6473
|
-
arch,
|
|
6474
|
-
M,
|
|
6475
|
-
N,
|
|
6476
|
-
cusolver_function_map["potrf"],
|
|
6477
|
-
precision_enum,
|
|
6478
|
-
cusolver_fill_mode_map["upper"],
|
|
6479
|
-
num_threads,
|
|
6480
|
-
)
|
|
6598
|
+
fill_mode = cusolver_fill_mode_map["upper"]
|
|
6481
6599
|
|
|
6482
|
-
|
|
6483
|
-
|
|
6484
|
-
|
|
6485
|
-
if Path(f.name).exists():
|
|
6486
|
-
Path(f.name).unlink()
|
|
6487
|
-
raise RuntimeError("Failed to compile tile_cholesky")
|
|
6600
|
+
arch = options["output_arch"]
|
|
6601
|
+
num_threads = options["block_dim"]
|
|
6602
|
+
parameter_list = f"({dtype}*, unsigned)"
|
|
6488
6603
|
|
|
6604
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6605
|
+
# CPU/no-MathDx dispatch
|
|
6606
|
+
return ((0, a, out), [], [], 0)
|
|
6489
6607
|
else:
|
|
6490
|
-
|
|
6491
|
-
|
|
6492
|
-
|
|
6493
|
-
|
|
6494
|
-
|
|
6495
|
-
|
|
6496
|
-
|
|
6497
|
-
|
|
6498
|
-
|
|
6499
|
-
|
|
6500
|
-
|
|
6608
|
+
# generate the LTO
|
|
6609
|
+
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
6610
|
+
M,
|
|
6611
|
+
N,
|
|
6612
|
+
solver,
|
|
6613
|
+
solver_enum,
|
|
6614
|
+
fill_mode,
|
|
6615
|
+
arch,
|
|
6616
|
+
precision_enum,
|
|
6617
|
+
num_threads,
|
|
6618
|
+
parameter_list,
|
|
6619
|
+
builder,
|
|
6620
|
+
)
|
|
6501
6621
|
|
|
6502
|
-
|
|
6622
|
+
return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
|
|
6503
6623
|
|
|
6504
6624
|
|
|
6505
6625
|
add_builtin(
|
|
@@ -6593,57 +6713,36 @@ def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
|
6593
6713
|
f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
|
|
6594
6714
|
)
|
|
6595
6715
|
|
|
6596
|
-
|
|
6597
|
-
|
|
6598
|
-
lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"
|
|
6599
|
-
|
|
6600
|
-
# early out if LTO for this combination already exists for this module
|
|
6601
|
-
if lto_symbol in builder.ltoirs:
|
|
6602
|
-
return lto_symbol, builder.ltoirs[lto_symbol]
|
|
6603
|
-
|
|
6604
|
-
# otherwise compile LTO
|
|
6605
|
-
lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6606
|
-
universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6716
|
+
solver = "potrs"
|
|
6717
|
+
solver_enum = cusolver_function_map[solver]
|
|
6607
6718
|
|
|
6608
|
-
# cuSOLVERDx only
|
|
6719
|
+
# cuSOLVERDx only supports col-major input/outputs,
|
|
6609
6720
|
# so we use upper to mimic a row-major input
|
|
6610
|
-
|
|
6611
|
-
universal_fatbin_code.name.encode("utf-8"),
|
|
6612
|
-
lto_code.name.encode("utf-8"),
|
|
6613
|
-
lto_symbol.encode("utf-8"),
|
|
6614
|
-
0,
|
|
6615
|
-
None,
|
|
6616
|
-
None,
|
|
6617
|
-
arch,
|
|
6618
|
-
M,
|
|
6619
|
-
N,
|
|
6620
|
-
cusolver_function_map["potrs"],
|
|
6621
|
-
precision_enum,
|
|
6622
|
-
cusolver_fill_mode_map["upper"],
|
|
6623
|
-
num_threads,
|
|
6624
|
-
)
|
|
6721
|
+
fill_mode = cusolver_fill_mode_map["upper"]
|
|
6625
6722
|
|
|
6626
|
-
|
|
6627
|
-
|
|
6628
|
-
|
|
6629
|
-
if Path(f.name).exists():
|
|
6630
|
-
Path(f.name).unlink()
|
|
6631
|
-
raise RuntimeError("Failed to compile tile_cholesky_solve")
|
|
6723
|
+
arch = options["output_arch"]
|
|
6724
|
+
num_threads = options["block_dim"]
|
|
6725
|
+
parameter_list = f"({dtype}*, {dtype}*)"
|
|
6632
6726
|
|
|
6727
|
+
if arch is None or not warp.context.runtime.core.is_mathdx_enabled():
|
|
6728
|
+
# CPU/no-MathDx dispatch
|
|
6729
|
+
return ((0, L, x, y), [], [], 0)
|
|
6633
6730
|
else:
|
|
6634
|
-
|
|
6635
|
-
|
|
6636
|
-
|
|
6637
|
-
|
|
6638
|
-
|
|
6639
|
-
|
|
6640
|
-
|
|
6641
|
-
|
|
6642
|
-
|
|
6643
|
-
|
|
6644
|
-
|
|
6645
|
-
|
|
6646
|
-
|
|
6731
|
+
# generate the LTO
|
|
6732
|
+
lto_symbol, lto_code_data = warp.build.build_lto_solver(
|
|
6733
|
+
M,
|
|
6734
|
+
N,
|
|
6735
|
+
solver,
|
|
6736
|
+
solver_enum,
|
|
6737
|
+
fill_mode,
|
|
6738
|
+
arch,
|
|
6739
|
+
precision_enum,
|
|
6740
|
+
num_threads,
|
|
6741
|
+
parameter_list,
|
|
6742
|
+
builder,
|
|
6743
|
+
)
|
|
6744
|
+
|
|
6745
|
+
return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
|
|
6647
6746
|
|
|
6648
6747
|
|
|
6649
6748
|
add_builtin(
|