warp_lang-1.7.0-py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang has been flagged as potentially problematic.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/examples/tile/example_tile_convolution.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+###########################################################################
+# Example Tile Convolution
+#
+# Shows how to write a simple convolution kernel using Warp FFT tile
+# primitives.
+#
+###########################################################################
+
+import numpy as np
+
+import warp as wp
+
+wp.set_module_options({"enable_backward": False})
+
+BLOCK_DIM = 64
+TILE_M = 1
+TILE_N = 128
+
+scale = wp.vec2d(wp.float64(1 / TILE_N), wp.float64(1 / TILE_N))
+
+
+@wp.func
+def filter(x: wp.vec2d):
+    return wp.cw_mul(x, scale)
+
+
+@wp.kernel
+def conv_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d)):
+    i, j, _ = wp.tid()
+    a = wp.tile_load(x, shape=(TILE_M, TILE_N))
+    wp.tile_fft(a)
+    b = wp.tile_map(filter, a)
+    wp.tile_ifft(b)
+    wp.tile_store(y, b)
+
+
+if __name__ == "__main__":
+    wp.set_device("cuda:0")
+
+    rng = np.random.default_rng(42)
+
+    x_h = rng.standard_normal((TILE_M, TILE_N, 2), dtype=np.float64)
+    y_h = np.zeros_like(x_h)
+
+    x_wp = wp.array2d(x_h, dtype=wp.vec2d)
+    y_wp = wp.array2d(y_h, dtype=wp.vec2d)
+
+    wp.launch_tiled(conv_tiled, dim=[1, 1], inputs=[x_wp], outputs=[y_wp], block_dim=BLOCK_DIM)
+
+    # Since filter is 1/N, conv_tiled is a ~no-op
+    assert np.allclose(x_h, y_wp.numpy())
warp/examples/tile/example_tile_fft.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+###########################################################################
+# Example Tile FFT
+#
+# Shows how to write a simple FFT kernel using Warp tile primitives.
+#
+###########################################################################
+
+import numpy as np
+
+import warp as wp
+
+wp.set_module_options({"enable_backward": False})
+
+BLOCK_DIM = 8
+TILE_M = 1
+TILE_N = 32
+
+
+@wp.kernel
+def fft_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d)):
+    i, j, _ = wp.tid()
+    a = wp.tile_load(x, shape=(TILE_M, TILE_N))
+    wp.tile_fft(a)
+    wp.tile_ifft(a)
+    wp.tile_store(y, a)
+
+
+if __name__ == "__main__":
+    wp.set_device("cuda:0")
+
+    x_h = np.ones((TILE_M, TILE_N, 2), dtype=np.float64)
+    x_h[:, :, 1] = 0
+    y_h = 3 * np.ones((TILE_M, TILE_N, 2), dtype=np.float64)
+    x_wp = wp.array2d(x_h, dtype=wp.vec2d)
+    y_wp = wp.array2d(y_h, dtype=wp.vec2d)
+
+    wp.launch_tiled(fft_tiled, dim=[1, 1], inputs=[x_wp], outputs=[y_wp], block_dim=BLOCK_DIM)
+
+    print("Inputs:\n", x_wp)  # [1+0i, 1+0i, 1+0i, ...]
+    print("Output:\n", y_wp)  # [32+0i, 0, 0, ...]
warp/examples/tile/example_tile_filtering.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+###########################################################################
+# Example Tile Filtering
+#
+# Shows how to write a simple filtering kernel using Warp FFT tile
+# primitives.
+#
+###########################################################################
+
+import numpy as np
+
+import warp as wp
+
+wp.set_module_options({"enable_backward": False})
+
+BLOCK_DIM = 128
+TILE_M = 1
+TILE_N = 512
+
+scale = wp.vec2d(wp.float64(1 / TILE_N), wp.float64(1 / TILE_N))
+
+
+def cplx(array):
+    return array[..., 0] + 1j * array[..., 1]
+
+
+@wp.func
+def cplx_prod(x: wp.vec2d, y: wp.vec2d):
+    return wp.cw_mul(wp.vec2d(x[0] * y[0] - x[1] * y[1], x[0] * y[1] + x[1] * y[0]), scale)
+
+
+@wp.kernel
+def conv_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d), z: wp.array2d(dtype=wp.vec2d)):
+    i, j, _ = wp.tid()
+    a = wp.tile_load(x, shape=(TILE_M, TILE_N))
+    b = wp.tile_load(y, shape=(TILE_M, TILE_N))
+    wp.tile_fft(a)
+    c = wp.tile_map(cplx_prod, a, b)
+    wp.tile_ifft(c)
+    wp.tile_store(z, c)
+
+
+if __name__ == "__main__":
+    rng = np.random.default_rng(42)
+
+    # Create noisy input signal
+    t = np.linspace(0, 2 * np.pi, TILE_N, dtype=np.float64)
+    x = np.sin(t) + 0.5 * rng.random(TILE_N, dtype=np.float64)
+
+    # Create filter. This filter keeps only ~10% of the frequencies at the center
+    # of the spectrum.
+    f = np.ones_like(x)
+    freq = np.fft.fftfreq(TILE_N)
+    f[np.abs(freq) > 0.05] = 0.0
+    f[np.abs(freq) <= 0.05] = 1.0
+
+    # Create Warp input data
+    # We use vec2d to hold complex numbers
+    x_h = np.zeros((TILE_M, TILE_N, 2), dtype=np.float64)
+    f_h = np.zeros_like(x_h)
+    y_h = np.zeros_like(f_h)
+
+    x_h[:, :, 0] = x
+    f_h[:, :, 0] = f
+
+    x_wp = wp.array2d(x_h, dtype=wp.vec2d)
+    f_wp = wp.array2d(f_h, dtype=wp.vec2d)
+    y_wp = wp.array2d(y_h, dtype=wp.vec2d)
+
+    wp.launch_tiled(conv_tiled, dim=[1, 1], inputs=[x_wp, f_wp], outputs=[y_wp], block_dim=BLOCK_DIM)
+
+    # Extract output and compare with numpy
+    x_np = cplx(x_h)
+    f_np = cplx(f_h)
+    y_test = cplx(y_wp.numpy())
+    y_ref = np.fft.ifft(f_np * np.fft.fft(x_np))
+    assert np.allclose(y_ref, y_test)
+
+    try:
+        import matplotlib.pyplot as plt
+
+        fig, ax = plt.subplots(figsize=(10, 5))
+
+        ax.plot(
+            x,
+            color="#DDDDDD",
+            linewidth=2,
+            label="Original",
+        )
+        ax.plot(y_test[0, :].real, color="#76B900", linewidth=3, label="Smoothed")
+
+        ax.legend()
+        ax.grid(True)
+
+        plt.tight_layout()
+        plt.show()
+
+    except ModuleNotFoundError:
+        print("Matplotlib not available; skipping figure")
warp/examples/tile/example_tile_matmul.py
@@ -0,0 +1,85 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+###########################################################################
+# Example Tile MatMul
+#
+# Shows how to write a simple GEMM kernel using Warp tile primitives.
+#
+###########################################################################
+
+import numpy as np
+
+import warp as wp
+
+# tile size
+TILE_M = wp.constant(8)
+TILE_N = wp.constant(4)
+TILE_K = wp.constant(8)
+
+# num threads per-tile
+TILE_THREADS = 64
+
+
+@wp.kernel
+def tile_gemm(A: wp.array2d(dtype=wp.float32), B: wp.array2d(dtype=wp.float16), C: wp.array2d(dtype=wp.float64)):
+    # output tile index
+    i, j = wp.tid()
+
+    sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float64)
+
+    _M = A.shape[0]
+    _N = B.shape[1]
+    K = A.shape[1]
+
+    count = int(K / TILE_K)
+
+    for k in range(0, count):
+        a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
+        b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
+
+        # sum += a*b
+        wp.tile_matmul(a, b, sum)
+
+    wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
+
+
+if __name__ == "__main__":
+    # generate some tile aligned matrix dimensions
+    M = TILE_M * 7
+    K = TILE_K * 6
+    N = TILE_N * 5
+
+    rng = np.random.default_rng(42)
+    A = rng.random((M, K), dtype=np.float32)
+    B = rng.random((K, N), dtype=np.float32).astype(np.float16)
+    C = np.zeros((M, N), dtype=np.float64)
+
+    A_wp = wp.array(A, requires_grad=True)
+    B_wp = wp.array(B, requires_grad=True)
+    C_wp = wp.array(C, requires_grad=True)
+
+    with wp.Tape() as tape:
+        wp.launch_tiled(
+            tile_gemm,
+            dim=(M // TILE_M, N // TILE_N),
+            inputs=[A_wp, B_wp],
+            outputs=[C_wp],
+            block_dim=TILE_THREADS,
+        )
+
+    assert np.allclose(C_wp.numpy(), A @ B, atol=1.0e-4)
+
+    print("Example matrix multiplication passed")
warp/examples/tile/example_tile_mlp.py
@@ -0,0 +1,383 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+###########################################################################
+# Example Image Multilayer Perceptron (MLP)
+#
+# Shows how to train a coordinate-based MLP on an image to predict the RGB
+# color at a given input position. By default, a positional encoding is
+# applied to the input coordinates to improve the ability of the MLP to
+# represent higher-frequency content. This can be disabled by passing the
+# '--no_encoding' option.
+#
+# References:
+# Ben Mildenhall et al. 2021. NeRF: representing scenes
+# as neural radiance fields for view synthesis. Commun. ACM 65, 1
+# (January 2022), 99–106. https://doi.org/10.1145/3503250
+#
+###########################################################################
+
+import math
+import os
+
+import numpy as np
+from PIL import Image
+
+import warp as wp
+import warp.examples
+import warp.optim
+
+rng = np.random.default_rng(45)
+
+
+def create_layer(dim_in, dim_hid, dtype=float):
+    w = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
+    b = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, 1))
+
+    weights = wp.array(w, dtype=dtype, requires_grad=True)
+    bias = wp.array(b, dtype=dtype, requires_grad=True)
+
+    return (weights, bias)
+
+
+def create_array(dim_in, dim_hid, dtype=float):
+    s = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
+    a = wp.array(s, dtype=dtype, requires_grad=True)
+
+    return a
+
+
+# number of frequencies for the positional encoding
+NUM_FREQ = wp.constant(8)
+
+DIM_IN = wp.constant(4 * NUM_FREQ)  # sin,cos for both x,y at each frequenecy
+DIM_HID = 32
+DIM_OUT = 3
+
+# threads per-block
+NUM_THREADS = 32
+
+IMG_WIDTH = 512
+IMG_HEIGHT = 512
+
+BATCH_SIZE = min(1024, int((IMG_WIDTH * IMG_HEIGHT) / 8))
+
+# dtype for our weights and bias matrices
+dtype = wp.float16
+
+
+@wp.func
+def relu(x: dtype):
+    return wp.max(x, dtype(0.0))
+
+
+@wp.kernel
+def compute(
+    indices: wp.array(dtype=int),
+    weights_0: wp.array2d(dtype=dtype),
+    bias_0: wp.array2d(dtype=dtype),
+    weights_1: wp.array2d(dtype=dtype),
+    bias_1: wp.array2d(dtype=dtype),
+    weights_2: wp.array2d(dtype=dtype),
+    bias_2: wp.array2d(dtype=dtype),
+    weights_3: wp.array2d(dtype=dtype),
+    bias_3: wp.array2d(dtype=dtype),
+    reference: wp.array2d(dtype=float),
+    loss: wp.array1d(dtype=float),
+    out: wp.array2d(dtype=float),
+):
+    # batch indices
+    linear = indices[wp.tid()]
+
+    row = linear / IMG_WIDTH
+    col = linear % IMG_WIDTH
+
+    # normalize input coordinates to [-1, 1]
+    x = (float(row) / float(IMG_WIDTH) - 0.5) * 2.0
+    y = (float(col) / float(IMG_HEIGHT) - 0.5) * 2.0
+
+    local = wp.vector(dtype=dtype, length=DIM_IN)
+
+    # construct positional encoding
+    for s in range(NUM_FREQ):
+        scale = wp.pow(2.0, float(s)) * wp.pi
+
+        # x-coord
+        local[s * 4 + 0] = dtype(wp.sin(x * scale))
+        local[s * 4 + 1] = dtype(wp.cos(x * scale))
+        # y-coord
+        local[s * 4 + 2] = dtype(wp.sin(y * scale))
+        local[s * 4 + 3] = dtype(wp.cos(y * scale))
+
+    # tile feature vectors across the block, returns [dim(f), NUM_THREADS]
+    f = wp.tile(local)
+
+    # input layer
+    w0 = wp.tile_load(weights_0, shape=(DIM_HID, DIM_IN))
+    b0 = wp.tile_load(bias_0, shape=(DIM_HID, 1))
+    z = wp.tile_map(relu, wp.tile_matmul(w0, f) + wp.tile_broadcast(b0, shape=(DIM_HID, NUM_THREADS)))
+
+    # hidden layer
+    w1 = wp.tile_load(weights_1, shape=(DIM_HID, DIM_HID))
+    b1 = wp.tile_load(bias_1, shape=(DIM_HID, 1))
+    z = wp.tile_map(relu, wp.tile_matmul(w1, z) + wp.tile_broadcast(b1, shape=(DIM_HID, NUM_THREADS)))
+
+    w2 = wp.tile_load(weights_2, shape=(DIM_HID, DIM_HID))
+    b2 = wp.tile_load(bias_2, shape=(DIM_HID, 1))
+    z = wp.tile_map(relu, wp.tile_matmul(w2, z) + wp.tile_broadcast(b2, shape=(DIM_HID, NUM_THREADS)))
+
+    # output layer
+    w3 = wp.tile_load(weights_3, shape=(DIM_OUT, DIM_HID))
+    b3 = wp.tile_load(bias_3, shape=(DIM_OUT, 1))
+    o = wp.tile_map(relu, wp.tile_matmul(w3, z) + wp.tile_broadcast(b3, shape=(DIM_OUT, NUM_THREADS)))
+
+    # untile back to SIMT
+    output = wp.untile(o)
+
+    # compute error
+    error = wp.vec3(
+        float(output[0]) - reference[0, linear],
+        float(output[1]) - reference[1, linear],
+        float(output[2]) - reference[2, linear],
+    )
+
+    # write MSE loss
+    if loss:
+        wp.atomic_add(loss, 0, wp.length_sq(error) / float(3 * BATCH_SIZE))
+
+    # write image output
+    if out:
+        for i in range(DIM_OUT):
+            out[i, linear] = float(output[i])
+
+
+class Example:
+    def __init__(self, train_iters):
+        self.weights_0, self.bias_0 = create_layer(DIM_IN, DIM_HID, dtype=dtype)
+        self.weights_1, self.bias_1 = create_layer(DIM_HID, DIM_HID, dtype=dtype)
+        self.weights_2, self.bias_2 = create_layer(DIM_HID, DIM_HID, dtype=dtype)
+        self.weights_3, self.bias_3 = create_layer(DIM_HID, DIM_OUT, dtype=dtype)
+
+        # reference
+        reference_path = os.path.join(wp.examples.get_asset_directory(), "pixel.jpg")
+        with Image.open(reference_path) as im:
+            reference_image = np.asarray(im.resize((IMG_WIDTH, IMG_HEIGHT)).convert("RGB")) / 255.0
+        self.reference = wp.array(reference_image.reshape(IMG_WIDTH * IMG_HEIGHT, 3).T, dtype=float)
+
+        # create randomized batch indices
+        indices = np.arange(0, IMG_WIDTH * IMG_HEIGHT, dtype=np.int32)
+        rng.shuffle(indices)
+        self.indices = wp.array(indices)
+
+        self.num_batches = int((IMG_WIDTH * IMG_HEIGHT) / BATCH_SIZE)
+        self.max_iters = train_iters
+        self.max_epochs = max(1, int(self.max_iters / self.num_batches))
+
+    def train_warp(self):
+        params = [
+            self.weights_0,
+            self.bias_0,
+            self.weights_1,
+            self.bias_1,
+            self.weights_2,
+            self.bias_2,
+            self.weights_3,
+            self.bias_3,
+        ]
+
+        optimizer_grads = [p.grad.flatten() for p in params]
+        optimizer_inputs = [p.flatten() for p in params]
+        optimizer = warp.optim.Adam(optimizer_inputs, lr=0.01)
+
+        loss = wp.zeros(1, dtype=float, requires_grad=True)
+        output = create_array(IMG_WIDTH * IMG_HEIGHT, DIM_OUT)
+
+        # capture graph for whole epoch
+        wp.capture_begin()
+
+        for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
+            loss.zero_()
+
+            with wp.Tape() as tape:
+                wp.launch(
+                    compute,
+                    dim=[BATCH_SIZE],
+                    inputs=[
+                        self.indices[b : b + BATCH_SIZE],
+                        self.weights_0,
+                        self.bias_0,
+                        self.weights_1,
+                        self.bias_1,
+                        self.weights_2,
+                        self.bias_2,
+                        self.weights_3,
+                        self.bias_3,
+                        self.reference,
+                        loss,
+                        None,
+                    ],
+                    block_dim=NUM_THREADS,
+                )
+
+            tape.backward(loss)
+            optimizer.step(optimizer_grads)
+            tape.zero()
+
+        graph = wp.capture_end()
+
+        with wp.ScopedTimer("Training"):
+            for i in range(self.max_epochs):
+                with wp.ScopedTimer("Epoch"):
+                    wp.capture_launch(graph)
+                    print(f"Epoch: {i} Loss: {loss.numpy()}")
+
+        # evaluate full image
+        wp.launch(
+            compute,
+            dim=[IMG_WIDTH * IMG_HEIGHT],
+            inputs=[
+                self.indices,
+                self.weights_0,
+                self.bias_0,
+                self.weights_1,
+                self.bias_1,
+                self.weights_2,
+                self.bias_2,
+                self.weights_3,
+                self.bias_3,
+                self.reference,
+                loss,
+                output,
+            ],
+            block_dim=NUM_THREADS,
+        )
+
+        self.save_image("example_tile_mlp.jpg", output.numpy())
+
+    def train_torch(self):
+        import torch as tc
+
+        weights_0 = tc.nn.Parameter(wp.to_torch(self.weights_0))
+        weights_1 = tc.nn.Parameter(wp.to_torch(self.weights_1))
+        weights_2 = tc.nn.Parameter(wp.to_torch(self.weights_2))
+        weights_3 = tc.nn.Parameter(wp.to_torch(self.weights_3))
+
+        bias_0 = tc.nn.Parameter(wp.to_torch(self.bias_0))
+        bias_1 = tc.nn.Parameter(wp.to_torch(self.bias_1))
+        bias_2 = tc.nn.Parameter(wp.to_torch(self.bias_2))
+        bias_3 = tc.nn.Parameter(wp.to_torch(self.bias_3))
+
+        indices = wp.to_torch(self.indices)
+        reference = wp.to_torch(self.reference)
+
+        optimizer = tc.optim.Adam(
+            [weights_0, bias_0, weights_1, bias_1, weights_2, bias_2, weights_3, bias_3],
+            capturable=True,
+            lr=0.0001,
+            betas=(0.9, 0.95),
+            eps=1.0e-6,
+        )
+
+        # generate frequency space encoding of pixels
+        # based on their linear index in the image
+        def encode(linear):
+            row = (linear // IMG_WIDTH).float()
+            col = (linear % IMG_WIDTH).float()
+
+            x = (row / float(IMG_WIDTH) - 0.5) * 2.0
+            y = (col / float(IMG_HEIGHT) - 0.5) * 2.0
+
+            encoding = tc.zeros((NUM_FREQ * 4, len(linear)), dtype=tc.float16, device="cuda")
+
+            for s in range(NUM_FREQ):
+                scale = math.pow(2.0, float(s)) * math.pi
+
+                # Directly write the computed values into the encoding tensor
+                encoding[s * 4 + 0, :] = tc.sin(scale * x)
+                encoding[s * 4 + 1, :] = tc.cos(scale * x)
+                encoding[s * 4 + 2, :] = tc.sin(scale * y)
+                encoding[s * 4 + 3, :] = tc.cos(scale * y)
+
+            return encoding
+
+        stream = tc.cuda.Stream()
+        graph = tc.cuda.CUDAGraph()
+
+        # warm-up
+        with tc.cuda.stream(stream):
+            f = tc.rand((NUM_FREQ * 4, BATCH_SIZE), dtype=tc.float16, device="cuda")
+            z = tc.relu(weights_0 @ f + bias_0)
+            z = tc.relu(weights_1 @ z + bias_1)
+            z = tc.relu(weights_2 @ z + bias_2)
+            z = tc.relu(weights_3 @ z + bias_3)
+            ref = tc.rand((3, BATCH_SIZE), dtype=tc.float16, device="cuda")
+            loss = tc.mean((z - ref) ** 2)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        with tc.cuda.graph(graph):
+            for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
+                linear = indices[b : b + BATCH_SIZE]
+
+                f = encode(linear)
+
+                z = tc.relu(weights_0 @ f + bias_0)
+                z = tc.relu(weights_1 @ z + bias_1)
+                z = tc.relu(weights_2 @ z + bias_2)
+                z = tc.relu(weights_3 @ z + bias_3)
+
+                ref = reference[:, linear]
+                loss = tc.mean((z - ref) ** 2)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        with wp.ScopedTimer("Training (Torch)"):
+            for _i in range(self.max_epochs):
+                with wp.ScopedTimer("Epoch"):
+                    graph.replay()
+
+                print(loss)
+
+        f = encode(tc.arange(0, IMG_WIDTH * IMG_HEIGHT))
+        z = tc.relu(weights_0 @ f + bias_0)
+        z = tc.relu(weights_1 @ z + bias_1)
+        z = tc.relu(weights_2 @ z + bias_2)
+        z = tc.relu(weights_3 @ z + bias_3)
+
+        self.save_image("example_tile_mlp_torch.jpg", z.detach().cpu().numpy())
+
+    def save_image(self, name, output):
+        predicted_image = output.T.reshape(IMG_WIDTH, IMG_HEIGHT, 3)
+        predicted_image = (predicted_image * 255).astype(np.uint8)
+
+        predicted_image_pil = Image.fromarray(predicted_image)
+        predicted_image_pil.save(name)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--train_iters", type=int, default=20000, help="Total number of training iterations.")
+
+    args = parser.parse_known_args()[0]
+
+    with wp.ScopedDevice("cuda:0"):
+        example = Example(args.train_iters)
+        example.train_warp()
+        # example.train_torch()