warp-lang 1.6.0__py3-none-manylinux2014_x86_64.whl → 1.6.2__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +14 -6
- warp/autograd.py +14 -6
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +14 -6
- warp/build_dll.py +14 -6
- warp/builtins.py +16 -7
- warp/codegen.py +24 -9
- warp/config.py +79 -27
- warp/constants.py +14 -6
- warp/context.py +236 -71
- warp/dlpack.py +14 -6
- warp/examples/__init__.py +14 -6
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +15 -7
- warp/examples/core/example_mesh.py +15 -7
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +15 -7
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +14 -6
- warp/examples/fem/example_burgers.py +14 -6
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +14 -6
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +14 -6
- warp/examples/fem/example_magnetostatics.py +14 -6
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +15 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/optim/example_walker.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +95 -33
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +14 -6
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +40 -21
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +15 -0
- warp/fem/adaptivity.py +15 -0
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +15 -0
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +15 -0
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +15 -0
- warp/fem/geometry/closest_point.py +15 -0
- warp/fem/geometry/deformed_geometry.py +15 -0
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +15 -0
- warp/fem/geometry/grid_2d.py +15 -0
- warp/fem/geometry/grid_3d.py +15 -0
- warp/fem/geometry/hexmesh.py +15 -0
- warp/fem/geometry/nanogrid.py +15 -0
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +15 -0
- warp/fem/geometry/tetmesh.py +15 -0
- warp/fem/geometry/trimesh.py +15 -0
- warp/fem/integrate.py +15 -0
- warp/fem/linalg.py +15 -0
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +15 -0
- warp/fem/quadrature/quadrature.py +15 -0
- warp/fem/space/__init__.py +15 -0
- warp/fem/space/basis_function_space.py +15 -0
- warp/fem/space/basis_space.py +15 -0
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +15 -0
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +15 -0
- warp/fem/space/quadmesh_function_space.py +15 -0
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +15 -0
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +15 -0
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +15 -0
- warp/fem/space/tetmesh_function_space.py +15 -0
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +15 -0
- warp/fem/types.py +15 -0
- warp/fem/utils.py +15 -0
- warp/jax.py +14 -6
- warp/jax_experimental.py +14 -6
- warp/math.py +14 -6
- warp/native/array.h +15 -6
- warp/native/builtin.h +15 -6
- warp/native/bvh.cpp +15 -6
- warp/native/bvh.cu +15 -6
- warp/native/bvh.h +15 -6
- warp/native/clang/clang.cpp +16 -7
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +16 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +15 -6
- warp/native/cuda_util.h +15 -6
- warp/native/cutlass_gemm.cpp +15 -6
- warp/native/cutlass_gemm.cu +16 -7
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +17 -0
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +15 -6
- warp/native/intersect_adj.h +15 -6
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +15 -6
- warp/native/marching.h +17 -0
- warp/native/mat.h +31 -9
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +15 -6
- warp/native/noise.h +15 -6
- warp/native/quat.h +15 -6
- warp/native/rand.h +15 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +15 -6
- warp/native/sort.cu +15 -6
- warp/native/sort.h +15 -6
- warp/native/sparse.cpp +15 -6
- warp/native/sparse.cu +15 -6
- warp/native/spatial.h +15 -6
- warp/native/svd.h +15 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +27 -14
- warp/native/tile_reduce.h +15 -6
- warp/native/vec.h +15 -6
- warp/native/volume.cpp +15 -6
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +15 -6
- warp/native/volume_builder.h +15 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +15 -6
- warp/native/warp.cu +15 -6
- warp/native/warp.h +15 -6
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +15 -0
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +37 -21
- warp/render/render_usd.py +24 -8
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +14 -6
- warp/sim/collide.py +43 -22
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +14 -7
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +34 -11
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +14 -6
- warp/sim/integrator_featherstone.py +18 -17
- warp/sim/integrator_vbd.py +15 -6
- warp/sim/integrator_xpbd.py +14 -6
- warp/sim/model.py +76 -65
- warp/sim/particles.py +14 -6
- warp/sim/render.py +16 -8
- warp/sim/utils.py +15 -0
- warp/sparse.py +15 -0
- warp/stubs.py +16 -1
- warp/tape.py +14 -6
- warp/tests/__main__.py +15 -0
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/disabled_kinematics.py +14 -6
- warp/tests/flaky_test_sim_grad.py +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +40 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_async.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +14 -6
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_bvh.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_collision.py +20 -12
- warp/tests/test_coloring.py +14 -7
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_dlpack.py +14 -6
- warp/tests/test_examples.py +21 -7
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +14 -6
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_hash_grid.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_ipc.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_jax.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +91 -32
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -0
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_marching_cubes.py +14 -6
- warp/tests/test_mat.py +89 -7
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +14 -6
- warp/tests/test_math.py +14 -6
- warp/tests/test_matmul.py +14 -6
- warp/tests/test_matmul_lite.py +14 -6
- warp/tests/test_mempool.py +14 -6
- warp/tests/test_mesh.py +14 -6
- warp/tests/test_mesh_query_aabb.py +14 -6
- warp/tests/test_mesh_query_point.py +14 -6
- warp/tests/test_mesh_query_ray.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_model.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_multigpu.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +19 -3
- warp/tests/test_paddle.py +14 -6
- warp/tests/test_peer.py +14 -6
- warp/tests/test_pinned.py +14 -6
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +14 -6
- warp/tests/test_rand.py +14 -6
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_sim_grad_bounce_linear.py +14 -6
- warp/tests/test_sim_kinematics.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +14 -6
- warp/tests/test_spatial.py +14 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +14 -6
- warp/tests/test_streams.py +14 -6
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_tile.py +14 -6
- warp/tests/test_tile_load.py +58 -7
- warp/tests/test_tile_mathdx.py +14 -6
- warp/tests/test_tile_mlp.py +14 -6
- warp/tests/test_tile_reduce.py +14 -6
- warp/tests/test_tile_shared_memory.py +14 -6
- warp/tests/test_tile_view.py +14 -6
- warp/tests/test_torch.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +14 -6
- warp/tests/test_vbd.py +14 -6
- warp/tests/test_vec.py +14 -6
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/test_volume.py +14 -6
- warp/tests/test_volume_write.py +14 -6
- warp/tests/unittest_serial.py +14 -6
- warp/tests/unittest_suites.py +14 -6
- warp/tests/unittest_utils.py +14 -6
- warp/tests/unused_test_misc.py +14 -6
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -7
- warp/torch.py +14 -6
- warp/types.py +80 -74
- warp/utils.py +14 -6
- warp_lang-1.6.2.dist-info/LICENSE.md +202 -0
- {warp_lang-1.6.0.dist-info → warp_lang-1.6.2.dist-info}/METADATA +44 -22
- warp_lang-1.6.2.dist-info/RECORD +419 -0
- {warp_lang-1.6.0.dist-info → warp_lang-1.6.2.dist-info}/WHEEL +1 -1
- warp_lang-1.6.0.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.0.dist-info/RECORD +0 -419
- {warp_lang-1.6.0.dist-info → warp_lang-1.6.2.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
from __future__ import annotations
|
|
9
17
|
|
|
@@ -34,6 +42,7 @@ import warp
|
|
|
34
42
|
import warp.build
|
|
35
43
|
import warp.codegen
|
|
36
44
|
import warp.config
|
|
45
|
+
from warp.types import launch_bounds_t
|
|
37
46
|
|
|
38
47
|
# represents either a built-in or user-defined function
|
|
39
48
|
|
|
@@ -5187,8 +5196,23 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
|
|
|
5187
5196
|
# represents all data required for a kernel launch
|
|
5188
5197
|
# so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
|
|
5189
5198
|
class Launch:
|
|
5199
|
+
"""Represents all data required for a kernel launch so that launches can be replayed quickly.
|
|
5200
|
+
|
|
5201
|
+
Users should not directly instantiate this class, instead use
|
|
5202
|
+
``wp.launch(..., record_cmd=True)`` to record a launch.
|
|
5203
|
+
"""
|
|
5204
|
+
|
|
5190
5205
|
def __init__(
|
|
5191
|
-
self,
|
|
5206
|
+
self,
|
|
5207
|
+
kernel,
|
|
5208
|
+
device: Device,
|
|
5209
|
+
hooks: Optional[KernelHooks] = None,
|
|
5210
|
+
params: Optional[Sequence[Any]] = None,
|
|
5211
|
+
params_addr: Optional[Sequence[ctypes.c_void_p]] = None,
|
|
5212
|
+
bounds: Optional[launch_bounds_t] = None,
|
|
5213
|
+
max_blocks: int = 0,
|
|
5214
|
+
block_dim: int = 256,
|
|
5215
|
+
adjoint: bool = False,
|
|
5192
5216
|
):
|
|
5193
5217
|
# retain the module executable so it doesn't get unloaded
|
|
5194
5218
|
self.module_exec = kernel.module.load(device)
|
|
@@ -5201,13 +5225,14 @@ class Launch:
|
|
|
5201
5225
|
|
|
5202
5226
|
# if not specified set a zero bound
|
|
5203
5227
|
if not bounds:
|
|
5204
|
-
bounds =
|
|
5228
|
+
bounds = launch_bounds_t(0)
|
|
5205
5229
|
|
|
5206
5230
|
# if not specified then build a list of default value params for args
|
|
5207
5231
|
if not params:
|
|
5208
5232
|
params = []
|
|
5209
5233
|
params.append(bounds)
|
|
5210
5234
|
|
|
5235
|
+
# Pack forward parameters
|
|
5211
5236
|
for a in kernel.adj.args:
|
|
5212
5237
|
if isinstance(a.type, warp.types.array):
|
|
5213
5238
|
params.append(a.type.__ctype__())
|
|
@@ -5216,6 +5241,18 @@ class Launch:
|
|
|
5216
5241
|
else:
|
|
5217
5242
|
params.append(pack_arg(kernel, a.type, a.label, 0, device, False))
|
|
5218
5243
|
|
|
5244
|
+
# Pack adjoint parameters if adjoint=True
|
|
5245
|
+
if adjoint:
|
|
5246
|
+
for a in kernel.adj.args:
|
|
5247
|
+
if isinstance(a.type, warp.types.array):
|
|
5248
|
+
params.append(a.type.__ctype__())
|
|
5249
|
+
elif isinstance(a.type, warp.codegen.Struct):
|
|
5250
|
+
params.append(a.type().__ctype__())
|
|
5251
|
+
else:
|
|
5252
|
+
# For primitive types in adjoint mode, initialize with 0
|
|
5253
|
+
params.append(pack_arg(kernel, a.type, a.label, 0, device, True))
|
|
5254
|
+
|
|
5255
|
+
# Create array of parameter addresses
|
|
5219
5256
|
kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
|
|
5220
5257
|
kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)
|
|
5221
5258
|
|
|
@@ -5225,13 +5262,30 @@ class Launch:
|
|
|
5225
5262
|
self.hooks = hooks
|
|
5226
5263
|
self.params = params
|
|
5227
5264
|
self.params_addr = params_addr
|
|
5228
|
-
self.device = device
|
|
5229
|
-
|
|
5230
|
-
|
|
5231
|
-
|
|
5265
|
+
self.device: Device = device
|
|
5266
|
+
"""The device to launch on.
|
|
5267
|
+
This should not be changed after the launch object is created.
|
|
5268
|
+
"""
|
|
5269
|
+
|
|
5270
|
+
self.bounds: launch_bounds_t = bounds
|
|
5271
|
+
"""The launch bounds. Update with :meth:`set_dim`."""
|
|
5272
|
+
|
|
5273
|
+
self.max_blocks: int = max_blocks
|
|
5274
|
+
"""The maximum number of CUDA thread blocks to use."""
|
|
5275
|
+
|
|
5276
|
+
self.block_dim: int = block_dim
|
|
5277
|
+
"""The number of threads per block."""
|
|
5232
5278
|
|
|
5233
|
-
|
|
5234
|
-
|
|
5279
|
+
self.adjoint: bool = adjoint
|
|
5280
|
+
"""Whether to run the adjoint kernel instead of the forward kernel."""
|
|
5281
|
+
|
|
5282
|
+
def set_dim(self, dim: Union[int, List[int], Tuple[int, ...]]):
|
|
5283
|
+
"""Set the launch dimensions.
|
|
5284
|
+
|
|
5285
|
+
Args:
|
|
5286
|
+
dim: The dimensions of the launch.
|
|
5287
|
+
"""
|
|
5288
|
+
self.bounds = launch_bounds_t(dim)
|
|
5235
5289
|
|
|
5236
5290
|
# launch bounds always at index 0
|
|
5237
5291
|
self.params[0] = self.bounds
|
|
@@ -5240,22 +5294,36 @@ class Launch:
|
|
|
5240
5294
|
if self.params_addr:
|
|
5241
5295
|
self.params_addr[0] = ctypes.c_void_p(ctypes.addressof(self.bounds))
|
|
5242
5296
|
|
|
5243
|
-
|
|
5244
|
-
|
|
5297
|
+
def set_param_at_index(self, index: int, value: Any, adjoint: bool = False):
|
|
5298
|
+
"""Set a kernel parameter at an index.
|
|
5299
|
+
|
|
5300
|
+
Args:
|
|
5301
|
+
index: The index of the param to set.
|
|
5302
|
+
value: The value to set the param to.
|
|
5303
|
+
"""
|
|
5245
5304
|
arg_type = self.kernel.adj.args[index].type
|
|
5246
5305
|
arg_name = self.kernel.adj.args[index].label
|
|
5247
5306
|
|
|
5248
|
-
carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device,
|
|
5307
|
+
carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, adjoint)
|
|
5308
|
+
|
|
5309
|
+
if adjoint:
|
|
5310
|
+
params_index = index + len(self.kernel.adj.args) + 1
|
|
5311
|
+
else:
|
|
5312
|
+
params_index = index + 1
|
|
5249
5313
|
|
|
5250
|
-
self.params[
|
|
5314
|
+
self.params[params_index] = carg
|
|
5251
5315
|
|
|
5252
5316
|
# for CUDA kernels we need to update the address to each arg
|
|
5253
5317
|
if self.params_addr:
|
|
5254
|
-
self.params_addr[
|
|
5318
|
+
self.params_addr[params_index] = ctypes.c_void_p(ctypes.addressof(carg))
|
|
5255
5319
|
|
|
5256
|
-
|
|
5257
|
-
|
|
5258
|
-
|
|
5320
|
+
def set_param_at_index_from_ctype(self, index: int, value: Union[ctypes.Structure, int, float]):
|
|
5321
|
+
"""Set a kernel parameter at an index without any type conversion.
|
|
5322
|
+
|
|
5323
|
+
Args:
|
|
5324
|
+
index: The index of the param to set.
|
|
5325
|
+
value: The value to set the param to.
|
|
5326
|
+
"""
|
|
5259
5327
|
if isinstance(value, ctypes.Structure):
|
|
5260
5328
|
# not sure how to directly assign struct->struct without reallocating using ctypes
|
|
5261
5329
|
self.params[index + 1] = value
|
|
@@ -5267,32 +5335,62 @@ class Launch:
|
|
|
5267
5335
|
else:
|
|
5268
5336
|
self.params[index + 1].__init__(value)
|
|
5269
5337
|
|
|
5270
|
-
|
|
5271
|
-
|
|
5338
|
+
def set_param_by_name(self, name: str, value: Any, adjoint: bool = False):
|
|
5339
|
+
"""Set a kernel parameter by argument name.
|
|
5340
|
+
|
|
5341
|
+
Args:
|
|
5342
|
+
name: The name of the argument to set.
|
|
5343
|
+
value: The value to set the argument to.
|
|
5344
|
+
adjoint: If ``True``, set the adjoint of this parameter instead of the forward parameter.
|
|
5345
|
+
"""
|
|
5272
5346
|
for i, arg in enumerate(self.kernel.adj.args):
|
|
5273
5347
|
if arg.label == name:
|
|
5274
|
-
self.set_param_at_index(i, value)
|
|
5348
|
+
self.set_param_at_index(i, value, adjoint)
|
|
5349
|
+
return
|
|
5350
|
+
|
|
5351
|
+
raise ValueError(f"Argument '{name}' not found in kernel '{self.kernel.key}'")
|
|
5352
|
+
|
|
5353
|
+
def set_param_by_name_from_ctype(self, name: str, value: ctypes.Structure):
|
|
5354
|
+
"""Set a kernel parameter by argument name with no type conversions.
|
|
5275
5355
|
|
|
5276
|
-
|
|
5277
|
-
|
|
5356
|
+
Args:
|
|
5357
|
+
name: The name of the argument to set.
|
|
5358
|
+
value: The value to set the argument to.
|
|
5359
|
+
"""
|
|
5278
5360
|
# lookup argument index
|
|
5279
5361
|
for i, arg in enumerate(self.kernel.adj.args):
|
|
5280
5362
|
if arg.label == name:
|
|
5281
5363
|
self.set_param_at_index_from_ctype(i, value)
|
|
5282
5364
|
|
|
5283
|
-
|
|
5284
|
-
|
|
5365
|
+
def set_params(self, values: Sequence[Any]):
|
|
5366
|
+
"""Set all parameters.
|
|
5367
|
+
|
|
5368
|
+
Args:
|
|
5369
|
+
values: A list of values to set the params to.
|
|
5370
|
+
"""
|
|
5285
5371
|
for i, v in enumerate(values):
|
|
5286
5372
|
self.set_param_at_index(i, v)
|
|
5287
5373
|
|
|
5288
|
-
|
|
5289
|
-
|
|
5374
|
+
def set_params_from_ctypes(self, values: Sequence[ctypes.Structure]):
|
|
5375
|
+
"""Set all parameters without performing type-conversions.
|
|
5376
|
+
|
|
5377
|
+
Args:
|
|
5378
|
+
values: A list of ctypes or basic int / float types.
|
|
5379
|
+
"""
|
|
5290
5380
|
for i, v in enumerate(values):
|
|
5291
5381
|
self.set_param_at_index_from_ctype(i, v)
|
|
5292
5382
|
|
|
5293
|
-
def launch(self, stream=None) ->
|
|
5383
|
+
def launch(self, stream: Optional[Stream] = None) -> None:
|
|
5384
|
+
"""Launch the kernel.
|
|
5385
|
+
|
|
5386
|
+
Args:
|
|
5387
|
+
stream: The stream to launch on.
|
|
5388
|
+
"""
|
|
5294
5389
|
if self.device.is_cpu:
|
|
5295
|
-
self.
|
|
5390
|
+
if self.adjoint:
|
|
5391
|
+
self.hooks.backward(*self.params)
|
|
5392
|
+
else:
|
|
5393
|
+
self.hooks.forward(*self.params)
|
|
5296
5394
|
else:
|
|
5297
5395
|
if stream is None:
|
|
5298
5396
|
stream = self.device.stream
|
|
@@ -5305,32 +5403,44 @@ class Launch:
|
|
|
5305
5403
|
if graph is not None:
|
|
5306
5404
|
graph.retain_module_exec(self.module_exec)
|
|
5307
5405
|
|
|
5308
|
-
|
|
5309
|
-
|
|
5310
|
-
|
|
5311
|
-
|
|
5312
|
-
|
|
5313
|
-
|
|
5314
|
-
|
|
5315
|
-
|
|
5316
|
-
|
|
5317
|
-
|
|
5406
|
+
if self.adjoint:
|
|
5407
|
+
runtime.core.cuda_launch_kernel(
|
|
5408
|
+
self.device.context,
|
|
5409
|
+
self.hooks.backward,
|
|
5410
|
+
self.bounds.size,
|
|
5411
|
+
self.max_blocks,
|
|
5412
|
+
self.block_dim,
|
|
5413
|
+
self.hooks.backward_smem_bytes,
|
|
5414
|
+
self.params_addr,
|
|
5415
|
+
stream.cuda_stream,
|
|
5416
|
+
)
|
|
5417
|
+
else:
|
|
5418
|
+
runtime.core.cuda_launch_kernel(
|
|
5419
|
+
self.device.context,
|
|
5420
|
+
self.hooks.forward,
|
|
5421
|
+
self.bounds.size,
|
|
5422
|
+
self.max_blocks,
|
|
5423
|
+
self.block_dim,
|
|
5424
|
+
self.hooks.forward_smem_bytes,
|
|
5425
|
+
self.params_addr,
|
|
5426
|
+
stream.cuda_stream,
|
|
5427
|
+
)
|
|
5318
5428
|
|
|
5319
5429
|
|
|
5320
5430
|
def launch(
|
|
5321
5431
|
kernel,
|
|
5322
|
-
dim:
|
|
5432
|
+
dim: Union[int, Sequence[int]],
|
|
5323
5433
|
inputs: Sequence = [],
|
|
5324
5434
|
outputs: Sequence = [],
|
|
5325
5435
|
adj_inputs: Sequence = [],
|
|
5326
5436
|
adj_outputs: Sequence = [],
|
|
5327
5437
|
device: Devicelike = None,
|
|
5328
|
-
stream: Stream = None,
|
|
5329
|
-
adjoint=False,
|
|
5330
|
-
record_tape=True,
|
|
5331
|
-
record_cmd=False,
|
|
5332
|
-
max_blocks=0,
|
|
5333
|
-
block_dim=256,
|
|
5438
|
+
stream: Optional[Stream] = None,
|
|
5439
|
+
adjoint: bool = False,
|
|
5440
|
+
record_tape: bool = True,
|
|
5441
|
+
record_cmd: bool = False,
|
|
5442
|
+
max_blocks: int = 0,
|
|
5443
|
+
block_dim: int = 256,
|
|
5334
5444
|
):
|
|
5335
5445
|
"""Launch a Warp kernel on the target device
|
|
5336
5446
|
|
|
@@ -5338,18 +5448,23 @@ def launch(
|
|
|
5338
5448
|
|
|
5339
5449
|
Args:
|
|
5340
5450
|
kernel: The name of a Warp kernel function, decorated with the ``@wp.kernel`` decorator
|
|
5341
|
-
dim: The number of threads to launch the kernel, can be an integer
|
|
5451
|
+
dim: The number of threads to launch the kernel, can be an integer or a
|
|
5452
|
+
sequence of integers with a maximum of 4 dimensions.
|
|
5342
5453
|
inputs: The input parameters to the kernel (optional)
|
|
5343
5454
|
outputs: The output parameters (optional)
|
|
5344
5455
|
adj_inputs: The adjoint inputs (optional)
|
|
5345
5456
|
adj_outputs: The adjoint outputs (optional)
|
|
5346
|
-
device: The device to launch on
|
|
5347
|
-
stream: The stream to launch on
|
|
5348
|
-
adjoint: Whether to run forward or backward pass (typically use False)
|
|
5349
|
-
record_tape: When
|
|
5350
|
-
|
|
5351
|
-
|
|
5352
|
-
|
|
5457
|
+
device: The device to launch on.
|
|
5458
|
+
stream: The stream to launch on.
|
|
5459
|
+
adjoint: Whether to run forward or backward pass (typically use ``False``).
|
|
5460
|
+
record_tape: When ``True``, the launch will be recorded the global
|
|
5461
|
+
:class:`wp.Tape() <warp.Tape>` object when present.
|
|
5462
|
+
record_cmd: When ``True``, the launch will return a :class:`Launch`
|
|
5463
|
+
object. The launch will not occur until the user calls
|
|
5464
|
+
:meth:`Launch.launch()`.
|
|
5465
|
+
max_blocks: The maximum number of CUDA thread blocks to use.
|
|
5466
|
+
Only has an effect for CUDA kernel launches.
|
|
5467
|
+
If negative or zero, the maximum hardware value will be used.
|
|
5353
5468
|
block_dim: The number of threads per block.
|
|
5354
5469
|
"""
|
|
5355
5470
|
|
|
@@ -5370,7 +5485,7 @@ def launch(
|
|
|
5370
5485
|
print(f"kernel: {kernel.key} dim: {dim} inputs: {inputs} outputs: {outputs} device: {device}")
|
|
5371
5486
|
|
|
5372
5487
|
# construct launch bounds
|
|
5373
|
-
bounds =
|
|
5488
|
+
bounds = launch_bounds_t(dim)
|
|
5374
5489
|
|
|
5375
5490
|
if bounds.size > 0:
|
|
5376
5491
|
# first param is the number of threads
|
|
@@ -5427,6 +5542,17 @@ def launch(
|
|
|
5427
5542
|
f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
|
|
5428
5543
|
)
|
|
5429
5544
|
|
|
5545
|
+
if record_cmd:
|
|
5546
|
+
launch = Launch(
|
|
5547
|
+
kernel=kernel,
|
|
5548
|
+
hooks=hooks,
|
|
5549
|
+
params=params,
|
|
5550
|
+
params_addr=None,
|
|
5551
|
+
bounds=bounds,
|
|
5552
|
+
device=device,
|
|
5553
|
+
adjoint=adjoint,
|
|
5554
|
+
)
|
|
5555
|
+
return launch
|
|
5430
5556
|
hooks.backward(*params)
|
|
5431
5557
|
|
|
5432
5558
|
else:
|
|
@@ -5437,7 +5563,13 @@ def launch(
|
|
|
5437
5563
|
|
|
5438
5564
|
if record_cmd:
|
|
5439
5565
|
launch = Launch(
|
|
5440
|
-
kernel=kernel,
|
|
5566
|
+
kernel=kernel,
|
|
5567
|
+
hooks=hooks,
|
|
5568
|
+
params=params,
|
|
5569
|
+
params_addr=None,
|
|
5570
|
+
bounds=bounds,
|
|
5571
|
+
device=device,
|
|
5572
|
+
adjoint=adjoint,
|
|
5441
5573
|
)
|
|
5442
5574
|
return launch
|
|
5443
5575
|
else:
|
|
@@ -5464,16 +5596,30 @@ def launch(
|
|
|
5464
5596
|
f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
|
|
5465
5597
|
)
|
|
5466
5598
|
|
|
5467
|
-
|
|
5468
|
-
|
|
5469
|
-
|
|
5470
|
-
|
|
5471
|
-
|
|
5472
|
-
|
|
5473
|
-
|
|
5474
|
-
|
|
5475
|
-
|
|
5476
|
-
|
|
5599
|
+
if record_cmd:
|
|
5600
|
+
launch = Launch(
|
|
5601
|
+
kernel=kernel,
|
|
5602
|
+
hooks=hooks,
|
|
5603
|
+
params=params,
|
|
5604
|
+
params_addr=kernel_params,
|
|
5605
|
+
bounds=bounds,
|
|
5606
|
+
device=device,
|
|
5607
|
+
max_blocks=max_blocks,
|
|
5608
|
+
block_dim=block_dim,
|
|
5609
|
+
adjoint=adjoint,
|
|
5610
|
+
)
|
|
5611
|
+
return launch
|
|
5612
|
+
else:
|
|
5613
|
+
runtime.core.cuda_launch_kernel(
|
|
5614
|
+
device.context,
|
|
5615
|
+
hooks.backward,
|
|
5616
|
+
bounds.size,
|
|
5617
|
+
max_blocks,
|
|
5618
|
+
block_dim,
|
|
5619
|
+
hooks.backward_smem_bytes,
|
|
5620
|
+
kernel_params,
|
|
5621
|
+
stream.cuda_stream,
|
|
5622
|
+
)
|
|
5477
5623
|
|
|
5478
5624
|
else:
|
|
5479
5625
|
if hooks.forward is None:
|
|
@@ -5493,7 +5639,6 @@ def launch(
|
|
|
5493
5639
|
block_dim=block_dim,
|
|
5494
5640
|
)
|
|
5495
5641
|
return launch
|
|
5496
|
-
|
|
5497
5642
|
else:
|
|
5498
5643
|
# launch
|
|
5499
5644
|
runtime.core.cuda_launch_kernel(
|
|
@@ -6286,6 +6431,26 @@ def export_functions_rst(file): # pragma: no cover
|
|
|
6286
6431
|
def export_stubs(file): # pragma: no cover
|
|
6287
6432
|
"""Generates stub file for auto-complete of builtin functions"""
|
|
6288
6433
|
|
|
6434
|
+
# Add copyright notice
|
|
6435
|
+
print(
|
|
6436
|
+
"""# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
6437
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
6438
|
+
#
|
|
6439
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6440
|
+
# you may not use this file except in compliance with the License.
|
|
6441
|
+
# You may obtain a copy of the License at
|
|
6442
|
+
#
|
|
6443
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6444
|
+
#
|
|
6445
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
6446
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
6447
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
6448
|
+
# See the License for the specific language governing permissions and
|
|
6449
|
+
# limitations under the License.
|
|
6450
|
+
""",
|
|
6451
|
+
file=file,
|
|
6452
|
+
)
|
|
6453
|
+
|
|
6289
6454
|
print(
|
|
6290
6455
|
"# Autogenerated file, do not edit, this file provides stubs for builtins autocomplete in VSCode, PyCharm, etc",
|
|
6291
6456
|
file=file,
|
warp/dlpack.py
CHANGED
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2023 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
# Python specification for DLpack:
|
|
9
17
|
# https://dmlc.github.io/dlpack/latest/python_spec.html
|
warp/examples/__init__.py
CHANGED
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import os
|
|
9
17
|
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import gc
|
|
9
17
|
import statistics as stats
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
# include parent path
|
|
9
17
|
import csv
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import cupy as cp
|
|
9
17
|
import cupyx as cpx
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import jax.lax
|
|
9
17
|
import jax.numpy as jnp
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
|
|
3
18
|
import cupy as cp
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import numpy as np
|
|
9
17
|
|