warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: the source registry flags this release of warp-lang as potentially problematic.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/types.py
CHANGED
@@ -1,9 +1,17 @@
-# Copyright (c) 2022 NVIDIA CORPORATION.
-#
-#
-#
-#
-#
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 from __future__ import annotations

@@ -12,7 +20,21 @@ import ctypes
 import inspect
 import struct
 import zlib
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+)

 import numpy as np
 import numpy.typing as npt
@@ -48,7 +70,9 @@ class Transformation(Generic[Float]):


 class Array(Generic[DType]):
-
+    device: Optional[warp.context.Device]
+    dtype: type
+    size: int


 int_tuple_type_hints = {
@@ -1131,7 +1155,7 @@ ARRAY_TYPE_FABRIC_INDEXED = 3
 class launch_bounds_t(ctypes.Structure):
     _fields_ = [("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS), ("ndim", ctypes.c_int32), ("size", ctypes.c_size_t)]

-    def __init__(self, shape):
+    def __init__(self, shape: Union[int, Sequence[int]]):
         if isinstance(shape, int):
             # 1d launch
             self.ndim = 1
@@ -1252,7 +1276,7 @@ _type_size_cache = {
 }


-def type_size_in_bytes(dtype):
+def type_size_in_bytes(dtype: type) -> int:
     size = _type_size_cache.get(dtype)

     if size is None:
@@ -1271,7 +1295,7 @@ def type_size_in_bytes(dtype):
     return size


-def type_to_warp(dtype):
+def type_to_warp(dtype: type) -> type:
     if dtype == float:
         return float32
     elif dtype == int:
@@ -1282,7 +1306,7 @@ def type_to_warp(dtype):
         return dtype


-def type_typestr(dtype):
+def type_typestr(dtype: type) -> str:
     if dtype == bool:
         return "|b1"
     elif dtype == float16:
@@ -1368,29 +1392,29 @@ def type_is_transformation(t):
     return getattr(t, "_wp_generic_type_hint_", None) is Transformation


-value_types = (int, float, builtins.bool) +
+value_types = (int, float, builtins.bool) + scalar_and_bool_types


 # returns true for all value types (int, float, bool, scalars, vectors, matrices)
-def type_is_value(x):
+def type_is_value(x: Any) -> builtins.bool:
     return x in value_types or hasattr(x, "_wp_scalar_type_")


 # equivalent of the above but for values
-def is_int(x):
+def is_int(x: Any) -> builtins.bool:
     return type_is_int(type(x))


-def is_float(x):
+def is_float(x: Any) -> builtins.bool:
     return type_is_float(type(x))


-def is_value(x):
+def is_value(x: Any) -> builtins.bool:
     return type_is_value(type(x))


-
-
+def is_array(a) -> builtins.bool:
+    """Return true if the passed *instance* is one of the array types."""
     return isinstance(a, array_types)


@@ -1457,21 +1481,21 @@ def types_equal(a, b, match_generic=False):
         if a_length is None or b_length is None or a_length == b_length:
             return True

-    a_origin =
-    b_origin =
+    a_origin = get_origin(a)
+    b_origin = get_origin(b)
     if a_origin is tuple and b_origin is tuple:
-        a_args =
-        b_args =
+        a_args = get_args(a)
+        b_args = get_args(b)
         if len(a_args) == len(b_args) and all(
             scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
         ):
             return True
     elif a_origin is tuple and isinstance(b, Sequence):
-        a_args =
+        a_args = get_args(a)
         if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
             return True
     elif b_origin is tuple and isinstance(a, Sequence):
-        b_args =
+        b_args = get_args(b)
         if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
             return True

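The tuple comparisons in the hunk above rely on the standard typing helpers now imported at the top of the module: get_origin recovers the unsubscripted origin of a parameterized hint and get_args its arguments, which is what lets tuple annotations be compared element-wise against sequences. A minimal illustration in plain Python, independent of Warp:

from typing import Tuple, get_args, get_origin

hint = Tuple[float, int]
print(get_origin(hint))  # tuple
print(get_args(hint))    # (<class 'float'>, <class 'int'>)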
@@ -1592,7 +1616,7 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
     return array_ctype


-class array(Array):
+class array(Array[DType]):
     """A fixed-size multi-dimensional array containing values of the same type.

     Attributes:
@@ -1621,21 +1645,21 @@ class array(Array):

     def __init__(
         self,
-        data:
-        dtype:
-        shape:
+        data: Union[List, Tuple, npt.NDArray, None] = None,
+        dtype: Any = Any,
+        shape: Union[int, Tuple[int, ...], List[int], None] = None,
         strides: Optional[Tuple[int, ...]] = None,
         length: Optional[int] = None,
         ptr: Optional[int] = None,
         capacity: Optional[int] = None,
         device=None,
-        pinned: bool = False,
-        copy: bool = True,
-        owner: bool = False,  # deprecated - pass deleter instead
+        pinned: builtins.bool = False,
+        copy: builtins.bool = True,
+        owner: builtins.bool = False,  # deprecated - pass deleter instead
         deleter: Optional[Callable[[int, int], None]] = None,
         ndim: Optional[int] = None,
         grad: Optional[array] = None,
-        requires_grad: bool = False,
+        requires_grad: builtins.bool = False,
     ):
         """Constructs a new Warp array object

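For reference, a small usage sketch of the constructor whose parameters are annotated in the hunk above; the array contents and device names are illustrative only, not taken from the diff:

import numpy as np
import warp as wp

# host array built from NumPy data, with gradient storage enabled
a = wp.array(data=np.arange(16, dtype=np.float32), dtype=wp.float32, device="cpu", requires_grad=True)

# zero-initialized array of vectors via the wp.zeros convenience helper
b = wp.zeros(shape=(4, 4), dtype=wp.vec3, device="cpu")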
@@ -2931,7 +2955,7 @@ def from_ipc_handle(

 # A base class for non-contiguous arrays, providing the implementation of common methods like
 # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
-class noncontiguous_array_base(
+class noncontiguous_array_base(Array[T]):
     def __init__(self, array_type_id):
         self.type_id = array_type_id
         self.is_contiguous = False
@@ -3028,12 +3052,18 @@ def check_index_array(indices, expected_device):
         raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")


-class indexedarray(noncontiguous_array_base
+class indexedarray(noncontiguous_array_base):
     # member attributes available during code-gen (e.g.: d = arr.shape[0])
     # (initialized when needed)
     _vars = None

-    def __init__(
+    def __init__(
+        self,
+        data: Optional[array] = None,
+        indices: Union[array, List[array], None] = None,
+        dtype=None,
+        ndim: Optional[int] = None,
+    ):
         super().__init__(ARRAY_TYPE_INDEXED)

         # canonicalize types
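The expanded indexedarray signature above can be exercised as follows; the data values here are illustrative only. The numpy() call works because the base class provides it, as noted in the comment in the previous hunk:

import numpy as np
import warp as wp

data = wp.array(np.arange(10, dtype=np.float32), dtype=wp.float32, device="cpu")
idx = wp.array(np.array([0, 2, 4], dtype=np.int32), dtype=wp.int32, device="cpu")

# non-contiguous view selecting elements 0, 2 and 4 of `data`
view = wp.indexedarray(data=data, indices=idx)
print(view.numpy())  # [0. 2. 4.]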
@@ -3224,7 +3254,7 @@ class Tile:
             return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,{'true' if requires_grad else 'false'}>()"
         else:
             # tile will be initialized by another call, e.g.: tile_transpose()
-            return "
+            return "nullptr"

     # return total tile size in bytes
     def size_in_bytes(self):
@@ -3626,7 +3656,7 @@ class Volume:
         instance.id = None
         return instance

-    def __init__(self, data: array, copy: bool = True):
+    def __init__(self, data: array, copy: builtins.bool = True):
         """Class representing a sparse grid.

         Args:
@@ -4353,6 +4383,15 @@ class Volume:
         translation_buf = (ctypes.c_float * 3)(translation[0], translation[1], translation[2])
         return transform_buf, translation_buf

+    # nanovdb types for which we instantiate the grid builder
+    # Should be in sync with WP_VOLUME_BUILDER_INSTANTIATE_TYPES in volume_builder.h
+    _supported_allocation_types = [
+        "int32",
+        "float",
+        "Vec3f",
+        "Vec4f",
+    ]
+
     @classmethod
     def allocate_by_tiles(
         cls,
@@ -4380,7 +4419,8 @@ class Volume:
                 or a floating point scalar type (2D N-by-3 array of :class:`warp.float32` or 1D array of `warp.vec3f` values), indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
             voxel_size (float or array-like): Voxel size(s) of the new volume. Ignored if `transform` is given.
-            bg_value (array-like,
+            bg_value (array-like, scalar or None): Value of unallocated voxels of the volume, also defines the volume's type. An index volume will be created if `bg_value` is ``None``.
+                Other supported grid types are `int`, `float`, `vec3f`, and `vec4f`.
             translation (array-like): Translation between the index and world spaces.
             transform (array-like): Linear transform between the index and world spaces. If ``None``, deduced from `voxel_size`.
             device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
@@ -4412,35 +4452,47 @@ class Volume:
                 translation_buf,
                 in_world_space,
             )
-        elif hasattr(bg_value, "__len__"):
-            volume.id = volume.runtime.core.volume_v_from_tiles_device(
-                volume.device.context,
-                ctypes.c_void_p(tile_points.ptr),
-                tile_points.shape[0],
-                transform_buf,
-                translation_buf,
-                in_world_space,
-                (ctypes.c_float * 3)(bg_value[0], bg_value[1], bg_value[2]),
-            )
-        elif isinstance(bg_value, int):
-            volume.id = volume.runtime.core.volume_i_from_tiles_device(
-                volume.device.context,
-                ctypes.c_void_p(tile_points.ptr),
-                tile_points.shape[0],
-                transform_buf,
-                translation_buf,
-                in_world_space,
-                bg_value,
-            )
         else:
-
+            # normalize background value type
+            grid_type = type_to_warp(type(bg_value))
+            if not (is_value(bg_value) or type_is_vector(grid_type)) and (
+                hasattr(bg_value, "__len__") and is_value(bg_value[0])
+            ):
+                # non-warp vectors are considered float, for backward compatibility
+                grid_type = vector(len(bg_value), dtype=float)
+
+            # look for corresponding nvdb type
+            try:
+                nvdb_type = next(
+                    typ
+                    for typ in Volume._supported_allocation_types
+                    if types_equal(grid_type, Volume._nvdb_type_to_dtype[typ])
+                )
+            except StopIteration as err:
+                raise TypeError(
+                    f"Unsupported bg_value type for volume allocation {type_repr(grid_type)}. Supported volume types are {', '.join(Volume._supported_allocation_types)}."
+                ) from err
+
+            # cast to ctype
+            # wrap scalar values in length-1 vectors to handle specific ctype conversion
+            if not type_is_vector(grid_type):
+                grid_type = vector(length=1, dtype=grid_type)
+
+            cvalue = grid_type(bg_value)
+            cvalue_ptr = ctypes.pointer(cvalue)
+            cvalue_size = ctypes.sizeof(cvalue)
+            cvalue_type = nvdb_type.encode("ascii")
+
+            volume.id = volume.runtime.core.volume_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
                 tile_points.shape[0],
                 transform_buf,
                 translation_buf,
                 in_world_space,
-
+                cvalue_ptr,
+                cvalue_size,
+                cvalue_type,
             )

         if volume.id == 0:
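With the generic volume_from_tiles_device path above, the same entry point now covers int, float, vec3f, and vec4f grids as well as index grids. A hedged usage sketch follows; the tile coordinates and voxel size are made up for illustration, and keyword names follow the docstring shown earlier in this diff:

import numpy as np
import warp as wp

# one tile anchored at the origin; any (N, 3) int32 array of tile coordinates works
tile_points = wp.array(np.zeros((1, 3), dtype=np.int32), dtype=wp.int32, device="cuda")

vol_f = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=0.0, device="cuda")                            # float grid
vol_v4 = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=wp.vec4f(0.0, 0.0, 0.0, 0.0), device="cuda")  # vec4f grid
vol_idx = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=None, device="cuda")                         # index grid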
@@ -4598,6 +4650,8 @@ def matmul(
 ):
     """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+    .. versionremoved:: 1.7
+
     .. deprecated:: 1.6
         Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4611,80 +4665,8 @@ def matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    warp.utils.warn(
-        "wp.matmul() is deprecated and will be removed in a\nfuture version. Use tile primitives instead.",
-        category=DeprecationWarning,
-        stacklevel=2,
-    )
-
-    device = a.device
-
-    if b.device != device or c.device != device or d.device != device:
-        raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
-    if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
-        raise RuntimeError(
-            "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
-        )

-
-    n = b.shape[1]
-    k = a.shape[1]
-    if b.shape != (k, n) or c.shape != (m, n) or d.shape != (m, n):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
-        )
-
-    if runtime.tape:
-        runtime.tape.record_func(
-            backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith),
-            arrays=[a, b, c, d],
-        )
-        if warp.config.verify_autograd_array_access:
-            d.mark_write()
-            a.mark_read()
-            b.mark_read()
-            c.mark_read()
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
-        return
-
-    cc = device.arch
-    ret = runtime.core.cutlass_gemm(
-        device.context,
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(c.ptr),
-        ctypes.c_void_p(d.ptr),
-        alpha,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        1,
-    )
-    if not ret:
-        raise RuntimeError("matmul failed.")
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def adj_matmul(
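Since wp.matmul now raises unconditionally, callers have to move to the tile primitives the docstring points to. The sketch below shows the general replacement pattern under stated assumptions: it only handles the case where the inner dimension equals TILE_K (no K-loop or accumulation), the tile sizes and launch configuration are illustrative, and the exact tile_load/tile_store/launch_tiled signatures should be checked against the tiles documentation for the installed Warp version.

import warp as wp

TILE_M, TILE_N, TILE_K = 32, 32, 32  # illustrative tile sizes

@wp.kernel
def tile_gemm(a: wp.array2d(dtype=wp.float32),
              b: wp.array2d(dtype=wp.float32),
              c: wp.array2d(dtype=wp.float32)):
    i, j = wp.tid()
    # cooperative loads of one block of A and one block of B per thread block
    a_tile = wp.tile_load(a, shape=(TILE_M, TILE_K), offset=(i * TILE_M, 0))
    b_tile = wp.tile_load(b, shape=(TILE_K, TILE_N), offset=(0, j * TILE_N))
    c_tile = wp.tile_matmul(a_tile, b_tile)
    wp.tile_store(c, c_tile, offset=(i * TILE_M, j * TILE_N))

# launched with one thread block per output tile, e.g.:
# wp.launch_tiled(tile_gemm, dim=(M // TILE_M, N // TILE_N), inputs=[a, b, c], block_dim=64)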
@@ -4716,171 +4698,8 @@ def adj_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
-
-    if (
-        b.device != device
-        or c.device != device
-        or adj_a.device != device
-        or adj_b.device != device
-        or adj_c.device != device
-        or adj_d.device != device
-    ):
-        raise RuntimeError(
-            "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
-        )
-
-    if (
-        a.dtype != b.dtype
-        or a.dtype != c.dtype
-        or a.dtype != adj_a.dtype
-        or a.dtype != adj_b.dtype
-        or a.dtype != adj_c.dtype
-        or a.dtype != adj_d.dtype
-    ):
-        raise RuntimeError(
-            "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not adj_a.is_contiguous and not adj_a.is_transposed)
-        or (not adj_b.is_contiguous and not adj_b.is_transposed)
-        or (not adj_c.is_contiguous)
-        or (not adj_d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
-        )

-
-    n = b.shape[1]
-    k = a.shape[1]
-    if (
-        a.shape != (m, k)
-        or b.shape != (k, n)
-        or c.shape != (m, n)
-        or adj_d.shape != (m, n)
-        or adj_a.shape != (m, k)
-        or adj_b.shape != (k, n)
-        or adj_c.shape != (m, n)
-    ):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
-                a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
-            )
-        )
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose(), dtype=np_dtype) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose(), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
-        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
-        return
-
-    cc = device.arch
-
-    # adj_a
-    if not a.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            k,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(b.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            alpha,
-            1.0,
-            True,
-            b.is_transposed,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            m,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(b.ptr),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            alpha,
-            1.0,
-            not b.is_transposed,
-            False,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_b
-    if not b.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            n,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a.ptr),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            alpha,
-            1.0,
-            a.is_transposed,
-            True,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            n,
-            k,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(a.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            alpha,
-            1.0,
-            False,
-            not a.is_transposed,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_c
-    warp.launch(
-        kernel=warp.utils.add_kernel_2d,
-        dim=adj_c.shape,
-        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
-        device=device,
-        record_tape=False,
-    )
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def batched_matmul(
@@ -4894,6 +4713,8 @@ def batched_matmul(
 ):
     """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+    .. versionremoved:: 1.7
+
     .. deprecated:: 1.6
         Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4907,107 +4728,8 @@ def batched_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
-
-    if b.device != device or c.device != device or d.device != device:
-        raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
-    if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
-        raise RuntimeError(
-            "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
-        )
-
-    m = a.shape[1]
-    n = b.shape[2]
-    k = a.shape[2]
-    batch_count = a.shape[0]
-    if b.shape != (batch_count, k, n) or c.shape != (batch_count, m, n) or d.shape != (batch_count, m, n):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
-        )

-
-        runtime.tape.record_func(
-            backward=lambda: adj_batched_matmul(
-                a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith
-            ),
-            arrays=[a, b, c, d],
-        )
-        if warp.config.verify_autograd_array_access:
-            d.mark_write()
-            a.mark_read()
-            b.mark_read()
-            c.mark_read()
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
-        return
-
-    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
-    max_batch_count = 65535
-    iters = int(batch_count / max_batch_count)
-    remainder = batch_count % max_batch_count
-
-    cc = device.arch
-    for i in range(iters):
-        idx_start = i * max_batch_count
-        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            n,
-            k,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(c[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(d[idx_start:idx_end, :, :].ptr),
-            alpha,
-            beta,
-            not a.is_transposed,
-            not b.is_transposed,
-            allow_tf32x3_arith,
-            max_batch_count,
-        )
-        if not ret:
-            raise RuntimeError("Batched matmul failed.")
-
-    idx_start = iters * max_batch_count
-    ret = runtime.core.cutlass_gemm(
-        device.context,
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a[idx_start:, :, :].ptr),
-        ctypes.c_void_p(b[idx_start:, :, :].ptr),
-        ctypes.c_void_p(c[idx_start:, :, :].ptr),
-        ctypes.c_void_p(d[idx_start:, :, :].ptr),
-        alpha,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        remainder,
-    )
-    if not ret:
-        raise RuntimeError("Batched matmul failed.")
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def adj_batched_matmul(
@@ -5037,270 +4759,8 @@ def adj_batched_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime

-
-
-    if (
-        b.device != device
-        or c.device != device
-        or adj_a.device != device
-        or adj_b.device != device
-        or adj_c.device != device
-        or adj_d.device != device
-    ):
-        raise RuntimeError(
-            "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
-        )
-
-    if (
-        a.dtype != b.dtype
-        or a.dtype != c.dtype
-        or a.dtype != adj_a.dtype
-        or a.dtype != adj_b.dtype
-        or a.dtype != adj_c.dtype
-        or a.dtype != adj_d.dtype
-    ):
-        raise RuntimeError(
-            "wp.adj_batched_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
-        )
-
-    m = a.shape[1]
-    n = b.shape[2]
-    k = a.shape[2]
-    batch_count = a.shape[0]
-    if (
-        b.shape != (batch_count, k, n)
-        or c.shape != (batch_count, m, n)
-        or adj_d.shape != (batch_count, m, n)
-        or adj_a.shape != (batch_count, m, k)
-        or adj_b.shape != (batch_count, k, n)
-        or adj_c.shape != (batch_count, m, n)
-    ):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
-                a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
-            )
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not adj_a.is_contiguous and not adj_a.is_transposed)
-        or (not adj_b.is_contiguous and not adj_b.is_transposed)
-        or (not adj_c.is_contiguous)
-        or (not adj_d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
-        )
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1)), dtype=np_dtype) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
-        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
-        return
-
-    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
-    max_batch_count = 65535
-    iters = int(batch_count / max_batch_count)
-    remainder = batch_count % max_batch_count
-
-    cc = device.arch
-
-    for i in range(iters):
-        idx_start = i * max_batch_count
-        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-
-        # adj_a
-        if not a.is_transposed:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                m,
-                k,
-                n,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                True,
-                b.is_transposed,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-        else:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                k,
-                m,
-                n,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                not b.is_transposed,
-                False,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-
-        # adj_b
-        if not b.is_transposed:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                k,
-                n,
-                m,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                a.is_transposed,
-                True,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-        else:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                n,
-                k,
-                m,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                False,
-                not a.is_transposed,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-
-    idx_start = iters * max_batch_count
-
-    # adj_a
-    if not a.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            k,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            True,
-            b.is_transposed,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            m,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            not b.is_transposed,
-            False,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_b
-    if not b.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            n,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            a.is_transposed,
-            True,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            n,
-            k,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            False,
-            not a.is_transposed,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_c
-    warp.launch(
-        kernel=warp.utils.add_kernel_3d,
-        dim=adj_c.shape,
-        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
-        device=device,
-        record_tape=False,
-    )
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 class HashGrid:
@@ -5683,7 +5143,7 @@ simple_type_codes = {
 }


-def get_type_code(arg_type):
+def get_type_code(arg_type: type) -> str:
     if arg_type == Any:
         # special case for generics
         # note: since Python 3.11 Any is a type, so we check for it first
@@ -5747,8 +5207,8 @@ def get_type_code(arg_type):
     raise TypeError(f"Unrecognized type '{arg_type}'")


-def get_signature(arg_types, func_name=None, arg_names=None):
-    type_codes = []
+def get_signature(arg_types: List[type], func_name: Optional[str] = None, arg_names: Optional[List[str]] = None) -> str:
+    type_codes: List[str] = []
     for i, arg_type in enumerate(arg_types):
         try:
             type_codes.append(get_type_code(arg_type))