warp-lang 1.0.2__py3-none-win_amd64.whl → 1.2.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +88 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3693 -3354
- warp/codegen.py +2925 -2792
- warp/config.py +40 -36
- warp/constants.py +49 -45
- warp/context.py +5409 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +381 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
- warp/examples/benchmarks/benchmark_launches.py +293 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +232 -219
- warp/examples/core/example_fluid.py +291 -267
- warp/examples/core/example_graph_capture.py +142 -126
- warp/examples/core/example_marching_cubes.py +186 -174
- warp/examples/core/example_mesh.py +172 -155
- warp/examples/core/example_mesh_intersect.py +203 -193
- warp/examples/core/example_nvdb.py +174 -170
- warp/examples/core/example_raycast.py +103 -90
- warp/examples/core/example_raymarch.py +197 -178
- warp/examples/core/example_render_opengl.py +183 -141
- warp/examples/core/example_sph.py +403 -387
- warp/examples/core/example_torch.py +219 -181
- warp/examples/core/example_wave.py +261 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +432 -389
- warp/examples/fem/example_burgers.py +262 -0
- warp/examples/fem/example_convection_diffusion.py +180 -168
- warp/examples/fem/example_convection_diffusion_dg.py +217 -209
- warp/examples/fem/example_deformed_geometry.py +175 -159
- warp/examples/fem/example_diffusion.py +199 -173
- warp/examples/fem/example_diffusion_3d.py +178 -152
- warp/examples/fem/example_diffusion_mgpu.py +219 -214
- warp/examples/fem/example_mixed_elasticity.py +242 -222
- warp/examples/fem/example_navier_stokes.py +257 -243
- warp/examples/fem/example_stokes.py +218 -192
- warp/examples/fem/example_stokes_transfer.py +263 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +258 -246
- warp/examples/optim/example_cloth_throw.py +220 -209
- warp/examples/optim/example_diffray.py +564 -536
- warp/examples/optim/example_drone.py +862 -835
- warp/examples/optim/example_inverse_kinematics.py +174 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
- warp/examples/optim/example_spring_cage.py +237 -231
- warp/examples/optim/example_trajectory.py +221 -199
- warp/examples/optim/example_walker.py +304 -293
- warp/examples/sim/example_cartpole.py +137 -129
- warp/examples/sim/example_cloth.py +194 -186
- warp/examples/sim/example_granular.py +122 -111
- warp/examples/sim/example_granular_collision_sdf.py +195 -186
- warp/examples/sim/example_jacobian_ik.py +234 -214
- warp/examples/sim/example_particle_chain.py +116 -105
- warp/examples/sim/example_quadruped.py +191 -180
- warp/examples/sim/example_rigid_chain.py +195 -187
- warp/examples/sim/example_rigid_contact.py +187 -177
- warp/examples/sim/example_rigid_force.py +125 -125
- warp/examples/sim/example_rigid_gyroscopic.py +107 -95
- warp/examples/sim/example_rigid_soft_contact.py +132 -122
- warp/examples/sim/example_soft_body.py +188 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +61 -27
- warp/fem/cache.py +403 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +16 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +748 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +437 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/nanogrid.py +455 -0
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1684 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +179 -292
- warp/fem/space/basis_space.py +522 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +148 -267
- warp/fem/space/grid_3d_function_space.py +167 -306
- warp/fem/space/hexmesh_function_space.py +253 -352
- warp/fem/space/nanogrid_function_space.py +202 -0
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +261 -369
- warp/fem/space/restriction.py +161 -160
- warp/fem/space/shape/__init__.py +90 -15
- warp/fem/space/shape/cube_shape_function.py +728 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +224 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +153 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1081 -1025
- warp/native/builtin.h +1603 -1560
- warp/native/bvh.cpp +402 -398
- warp/native/bvh.cu +533 -525
- warp/native/bvh.h +430 -429
- warp/native/clang/clang.cpp +496 -464
- warp/native/crt.cpp +42 -32
- warp/native/crt.h +352 -335
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/exports.h +187 -0
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1545 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +292 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -4782
- warp/native/nanovdb/PNanoVDB.h +3390 -2553
- warp/native/noise.h +850 -850
- warp/native/quat.h +1112 -1085
- warp/native/rand.h +303 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1177 -1133
- warp/native/volume.cpp +529 -297
- warp/native/volume.cu +58 -32
- warp/native/volume.h +960 -538
- warp/native/volume_builder.cu +446 -425
- warp/native/volume_builder.h +34 -19
- warp/native/volume_impl.h +61 -0
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2949 -2828
- warp/native/warp.h +321 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3356 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1917 -1991
- warp/sim/integrator_xpbd.py +3288 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1289 -1227
- warp/stubs.py +2192 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +20 -22
- warp/tests/aux_test_grad_customs.py +21 -23
- warp/tests/aux_test_reference.py +9 -11
- warp/tests/aux_test_reference_reference.py +8 -10
- warp/tests/aux_test_square.py +15 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +237 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +155 -157
- warp/tests/test_arithmetic.py +1088 -1124
- warp/tests/test_array.py +2415 -2326
- warp/tests/test_array_reduce.py +148 -150
- warp/tests/test_async.py +666 -656
- warp/tests/test_atomic.py +139 -141
- warp/tests/test_bool.py +212 -149
- warp/tests/test_builtins_resolution.py +1290 -1292
- warp/tests/test_bvh.py +162 -171
- warp/tests/test_closest_point_edge_edge.py +227 -228
- warp/tests/test_codegen.py +562 -553
- warp/tests/test_compile_consts.py +217 -101
- warp/tests/test_conditional.py +244 -246
- warp/tests/test_copy.py +230 -215
- warp/tests/test_ctypes.py +630 -632
- warp/tests/test_dense.py +65 -67
- warp/tests/test_devices.py +89 -98
- warp/tests/test_dlpack.py +528 -529
- warp/tests/test_examples.py +403 -378
- warp/tests/test_fabricarray.py +952 -955
- warp/tests/test_fast_math.py +60 -54
- warp/tests/test_fem.py +1298 -1278
- warp/tests/test_fp16.py +128 -130
- warp/tests/test_func.py +336 -337
- warp/tests/test_generics.py +596 -571
- warp/tests/test_grad.py +885 -640
- warp/tests/test_grad_customs.py +331 -336
- warp/tests/test_hash_grid.py +208 -164
- warp/tests/test_import.py +37 -39
- warp/tests/test_indexedarray.py +1132 -1134
- warp/tests/test_intersect.py +65 -67
- warp/tests/test_jax.py +305 -307
- warp/tests/test_large.py +169 -164
- warp/tests/test_launch.py +352 -354
- warp/tests/test_lerp.py +217 -261
- warp/tests/test_linear_solvers.py +189 -171
- warp/tests/test_lvalue.py +419 -493
- warp/tests/test_marching_cubes.py +63 -65
- warp/tests/test_mat.py +1799 -1827
- warp/tests/test_mat_lite.py +113 -115
- warp/tests/test_mat_scalar_ops.py +2905 -2889
- warp/tests/test_math.py +124 -193
- warp/tests/test_matmul.py +498 -499
- warp/tests/test_matmul_lite.py +408 -410
- warp/tests/test_mempool.py +186 -190
- warp/tests/test_mesh.py +281 -324
- warp/tests/test_mesh_query_aabb.py +226 -241
- warp/tests/test_mesh_query_point.py +690 -702
- warp/tests/test_mesh_query_ray.py +290 -303
- warp/tests/test_mlp.py +274 -276
- warp/tests/test_model.py +108 -110
- warp/tests/test_module_hashing.py +111 -0
- warp/tests/test_modules_lite.py +36 -39
- warp/tests/test_multigpu.py +161 -163
- warp/tests/test_noise.py +244 -248
- warp/tests/test_operators.py +248 -250
- warp/tests/test_options.py +121 -125
- warp/tests/test_peer.py +131 -137
- warp/tests/test_pinned.py +76 -78
- warp/tests/test_print.py +52 -54
- warp/tests/test_quat.py +2084 -2086
- warp/tests/test_rand.py +324 -288
- warp/tests/test_reload.py +207 -217
- warp/tests/test_rounding.py +177 -179
- warp/tests/test_runlength_encode.py +188 -190
- warp/tests/test_sim_grad.py +241 -0
- warp/tests/test_sim_kinematics.py +89 -97
- warp/tests/test_smoothstep.py +166 -168
- warp/tests/test_snippet.py +303 -266
- warp/tests/test_sparse.py +466 -460
- warp/tests/test_spatial.py +2146 -2148
- warp/tests/test_special_values.py +362 -0
- warp/tests/test_streams.py +484 -473
- warp/tests/test_struct.py +708 -675
- warp/tests/test_tape.py +171 -148
- warp/tests/test_torch.py +741 -743
- warp/tests/test_transient_module.py +85 -87
- warp/tests/test_types.py +554 -659
- warp/tests/test_utils.py +488 -499
- warp/tests/test_vec.py +1262 -1268
- warp/tests/test_vec_lite.py +71 -73
- warp/tests/test_vec_scalar_ops.py +2097 -2099
- warp/tests/test_verify_fp.py +92 -94
- warp/tests/test_volume.py +961 -736
- warp/tests/test_volume_write.py +338 -265
- warp/tests/unittest_serial.py +38 -37
- warp/tests/unittest_suites.py +367 -359
- warp/tests/unittest_utils.py +434 -578
- warp/tests/unused_test_misc.py +69 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +563 -561
- warp/torch.py +321 -295
- warp/types.py +4941 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
- warp_lang-1.2.0.dist-info/RECORD +359 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
- warp/native/nanovdb/PNanoVDBWrite.h +0 -295
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
warp/tests/test_jax.py
CHANGED
|
@@ -1,307 +1,305 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
-
# and proprietary rights in and to this software, related documentation
|
|
4
|
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
-
# distribution of this software and related documentation without an express
|
|
6
|
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
-
|
|
8
|
-
import
|
|
9
|
-
import
|
|
10
|
-
import
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
tid =
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
tid =
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
tid =
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
)
|
|
51
|
-
tid =
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
for
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
return jax.__version_info__
|
|
76
|
-
except ImportError:
|
|
77
|
-
return (0, 0, 0)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def test_dtype_from_jax(test, device):
    """Check that ``wp.dtype_from_jax`` maps every Jax scalar type to the matching Warp type.

    Each Jax type is checked both as the scalar type object (e.g. ``jp.float32``) and as the
    ``jp.dtype`` instance built from it.
    """
    import jax.numpy as jp

    expected_pairs = (
        (jp.float16, wp.float16),
        (jp.float32, wp.float32),
        (jp.float64, wp.float64),
        (jp.int8, wp.int8),
        (jp.int16, wp.int16),
        (jp.int32, wp.int32),
        (jp.int64, wp.int64),
        (jp.uint8, wp.uint8),
        (jp.uint16, wp.uint16),
        (jp.uint32, wp.uint32),
        (jp.uint64, wp.uint64),
        (jp.bool_, wp.bool),
    )

    for jax_type, warp_type in expected_pairs:
        test.assertEqual(wp.dtype_from_jax(jax_type), warp_type)
        test.assertEqual(wp.dtype_from_jax(jp.dtype(jax_type)), warp_type)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def test_dtype_to_jax(test, device):
    """Check that ``wp.dtype_to_jax`` maps every Warp scalar type to the matching Jax type."""
    import jax.numpy as jp

    expected_pairs = (
        (wp.float16, jp.float16),
        (wp.float32, jp.float32),
        (wp.float64, jp.float64),
        (wp.int8, jp.int8),
        (wp.int16, jp.int16),
        (wp.int32, jp.int32),
        (wp.int64, jp.int64),
        (wp.uint8, jp.uint8),
        (wp.uint16, jp.uint16),
        (wp.uint32, jp.uint32),
        (wp.uint64, jp.uint64),
        (wp.bool, jp.bool_),
    )

    for warp_type, jax_type in expected_pairs:
        test.assertEqual(wp.dtype_to_jax(warp_type), jax_type)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
def test_device_conversion(test, device):
    """Check that a Warp device survives a round trip through its Jax counterpart."""
    round_tripped = wp.device_from_jax(wp.device_to_jax(device))
    test.assertEqual(round_tripped, device)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
|
|
128
|
-
def test_jax_kernel_basic(test, device):
|
|
129
|
-
import jax.numpy as jp
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
# get the concrete overload
|
|
166
|
-
kernel_instance = triple_kernel_scalar.
|
|
167
|
-
|
|
168
|
-
jax_triple = jax_kernel(kernel_instance)
|
|
169
|
-
|
|
170
|
-
@jax.jit
|
|
171
|
-
def f():
|
|
172
|
-
x = jp.arange(n, dtype=jp_dtype)
|
|
173
|
-
return jax_triple(x)
|
|
174
|
-
|
|
175
|
-
# run on the given device
|
|
176
|
-
with jax.default_device(wp.device_to_jax(device)):
|
|
177
|
-
y = f()
|
|
178
|
-
|
|
179
|
-
result = np.asarray(y)
|
|
180
|
-
expected = 3 * np.arange(n, dtype=np_dtype)
|
|
181
|
-
|
|
182
|
-
assert_np_equal(result, expected)
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
|
|
186
|
-
def test_jax_kernel_vecmat(test, device):
|
|
187
|
-
import jax.numpy as jp
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
jp_dtype = wp.dtype_to_jax(T._wp_scalar_type_)
|
|
193
|
-
np_dtype = wp.dtype_to_numpy(T._wp_scalar_type_)
|
|
194
|
-
|
|
195
|
-
n = 64 // T._length_
|
|
196
|
-
scalar_shape = (n, *T._shape_)
|
|
197
|
-
scalar_len = n * T._length_
|
|
198
|
-
|
|
199
|
-
with test.subTest(msg=T.__name__):
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
from warp.jax_experimental import jax_kernel
|
|
225
|
-
|
|
226
|
-
n = 64
|
|
227
|
-
|
|
228
|
-
jax_multiarg = jax_kernel(multiarg_kernel)
|
|
229
|
-
|
|
230
|
-
@jax.jit
|
|
231
|
-
def f():
|
|
232
|
-
a = jp.full(n, 1, dtype=jp.float32)
|
|
233
|
-
b = jp.full(n, 2, dtype=jp.float32)
|
|
234
|
-
c = jp.full(n, 3, dtype=jp.float32)
|
|
235
|
-
return jax_multiarg(a, b, c)
|
|
236
|
-
|
|
237
|
-
# run on the given device
|
|
238
|
-
with jax.default_device(wp.device_to_jax(device)):
|
|
239
|
-
x, y = f()
|
|
240
|
-
|
|
241
|
-
result_x, result_y = np.asarray(x), np.asarray(y)
|
|
242
|
-
expected_x = np.full(n, 3, dtype=np.float32)
|
|
243
|
-
expected_y = np.full(n, 5, dtype=np.float32)
|
|
244
|
-
|
|
245
|
-
assert_np_equal(result_x, expected_x)
|
|
246
|
-
assert_np_equal(result_y, expected_y)
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
class TestJax(unittest.TestCase):
    """Container for the Jax interop tests; test methods are attached at import time via add_function_test."""
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
# try adding Jax tests if Jax is installed correctly
|
|
254
|
-
try:
|
|
255
|
-
# prevent Jax from gobbling up GPU memory
|
|
256
|
-
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
|
|
257
|
-
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
|
|
258
|
-
|
|
259
|
-
import jax
|
|
260
|
-
import jax.dlpack
|
|
261
|
-
|
|
262
|
-
# NOTE: we must enable 64-bit types in Jax to test the full gamut of types
|
|
263
|
-
jax.config.update("jax_enable_x64", True)
|
|
264
|
-
|
|
265
|
-
# check which Warp devices work with Jax
|
|
266
|
-
# CUDA devices may fail if Jax cannot find a CUDA Toolkit
|
|
267
|
-
test_devices = get_test_devices()
|
|
268
|
-
jax_compatible_devices = []
|
|
269
|
-
jax_compatible_cuda_devices = []
|
|
270
|
-
for d in test_devices:
|
|
271
|
-
try:
|
|
272
|
-
with jax.default_device(wp.device_to_jax(d)):
|
|
273
|
-
j = jax.numpy.arange(10, dtype=jax.numpy.float32)
|
|
274
|
-
j += 1
|
|
275
|
-
jax_compatible_devices.append(d)
|
|
276
|
-
if d.is_cuda:
|
|
277
|
-
jax_compatible_cuda_devices.append(d)
|
|
278
|
-
except Exception as e:
|
|
279
|
-
print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
|
|
280
|
-
|
|
281
|
-
add_function_test(TestJax, "test_dtype_from_jax", test_dtype_from_jax, devices=None)
|
|
282
|
-
add_function_test(TestJax, "test_dtype_to_jax", test_dtype_to_jax, devices=None)
|
|
283
|
-
|
|
284
|
-
if jax_compatible_devices:
|
|
285
|
-
add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)
|
|
286
|
-
|
|
287
|
-
if jax_compatible_cuda_devices:
|
|
288
|
-
add_function_test(
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
wp.build.clear_kernel_cache()
|
|
307
|
-
unittest.main(verbosity=2)
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import unittest
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
import warp as wp
|
|
15
|
+
from warp.tests.unittest_utils import *
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# basic kernel with one input and output:
# each thread index tid writes output[tid] = 3.0 * input[tid] (float arrays only)
@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = 3.0 * input[tid]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# generic kernel with one scalar input and output.
# The array dtype is left as Any; concrete per-scalar-type overloads are created with
# wp.overload / add_overload elsewhere in this file.
@wp.kernel
def triple_kernel_scalar(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
    tid = wp.tid()
    # input.dtype(3) builds the constant 3 in the array's own scalar type, so the
    # multiplication type-checks for integer, unsigned and floating-point overloads alike
    output[tid] = input.dtype(3) * input[tid]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# generic kernel with one vector/matrix input and output.
# The array dtype is left as Any; concrete overloads for vector and matrix types are created
# with wp.overload / add_overload elsewhere in this file.
@wp.kernel
def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
    tid = wp.tid()
    # input.dtype is the vector/matrix type and input.dtype.dtype its component scalar type,
    # so the constant 3 is built in the component type and scales the whole element
    output[tid] = input.dtype.dtype(3) * input[tid]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# kernel with multiple inputs and outputs:
# computes ab = a + b and bc = b + c elementwise.
# When wrapped with jax_kernel in test_jax_kernel_multiarg, only a, b, c are passed in and
# the two trailing arrays (ab, bc) come back as the results.
@wp.kernel
def multiarg_kernel(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    c: wp.array(dtype=float),
    # outputs
    ab: wp.array(dtype=float),
    bc: wp.array(dtype=float),
):
    tid = wp.tid()
    ab[tid] = a[tid] + b[tid]
    bc[tid] = b[tid] + c[tid]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Types exercised by the tests: every Warp scalar type, plus 2-, 3- and 4-dimensional
# vectors and square matrices over each scalar type (dimension is the outer loop,
# scalar type the inner one).
scalar_types = wp.types.scalar_types
vector_types = [wp.vec(dim, T) for dim in (2, 3, 4) for T in scalar_types]
matrix_types = [wp.mat((dim, dim), T) for dim in (2, 3, 4) for T in scalar_types]

# Register every concrete overload of the generic kernels up front, so the tests never
# add new overloads (and thereby trigger a module reload) while they run.
for T in scalar_types:
    wp.overload(triple_kernel_scalar, [wp.array(dtype=T), wp.array(dtype=T)])
for T in [*vector_types, *matrix_types]:
    wp.overload(triple_kernel_vecmat, [wp.array(dtype=T), wp.array(dtype=T)])
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _jax_version():
|
|
72
|
+
try:
|
|
73
|
+
import jax
|
|
74
|
+
|
|
75
|
+
return jax.__version_info__
|
|
76
|
+
except ImportError:
|
|
77
|
+
return (0, 0, 0)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_dtype_from_jax(test, device):
    """Check that ``wp.dtype_from_jax`` maps every Jax scalar type to the matching Warp type.

    Each Jax type is checked both as the scalar type object (e.g. ``jp.float32``) and as the
    ``jp.dtype`` instance built from it.
    """
    import jax.numpy as jp

    expected_pairs = (
        (jp.float16, wp.float16),
        (jp.float32, wp.float32),
        (jp.float64, wp.float64),
        (jp.int8, wp.int8),
        (jp.int16, wp.int16),
        (jp.int32, wp.int32),
        (jp.int64, wp.int64),
        (jp.uint8, wp.uint8),
        (jp.uint16, wp.uint16),
        (jp.uint32, wp.uint32),
        (jp.uint64, wp.uint64),
        (jp.bool_, wp.bool),
    )

    for jax_type, warp_type in expected_pairs:
        test.assertEqual(wp.dtype_from_jax(jax_type), warp_type)
        test.assertEqual(wp.dtype_from_jax(jp.dtype(jax_type)), warp_type)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def test_dtype_to_jax(test, device):
    """Check that ``wp.dtype_to_jax`` maps every Warp scalar type to the matching Jax type."""
    import jax.numpy as jp

    expected_pairs = (
        (wp.float16, jp.float16),
        (wp.float32, jp.float32),
        (wp.float64, jp.float64),
        (wp.int8, jp.int8),
        (wp.int16, jp.int16),
        (wp.int32, jp.int32),
        (wp.int64, jp.int64),
        (wp.uint8, jp.uint8),
        (wp.uint16, jp.uint16),
        (wp.uint32, jp.uint32),
        (wp.uint64, jp.uint64),
        (wp.bool, jp.bool_),
    )

    for warp_type, jax_type in expected_pairs:
        test.assertEqual(wp.dtype_to_jax(warp_type), jax_type)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def test_device_conversion(test, device):
    """Check that a Warp device survives a round trip through its Jax counterpart."""
    round_tripped = wp.device_from_jax(wp.device_to_jax(device))
    test.assertEqual(round_tripped, device)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_basic(test, device):
    """Call the float32 ``triple_kernel`` through ``jax_kernel`` from inside a jitted Jax function."""
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    count = 64
    tripled = jax_kernel(triple_kernel)

    @jax.jit
    def run():
        return tripled(jp.arange(count, dtype=jp.float32))

    # execute with the Jax device that corresponds to the Warp test device
    with jax.default_device(wp.device_to_jax(device)):
        output = run()

    assert_np_equal(np.asarray(output).reshape((count,)), 3 * np.arange(count, dtype=np.float32))
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_scalar(test, device):
    """Call ``triple_kernel_scalar`` through ``jax_kernel`` once per Warp scalar type (one subtest each)."""
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    count = 64

    for scalar_type in scalar_types:
        jp_dtype = wp.dtype_to_jax(scalar_type)
        np_dtype = wp.dtype_to_numpy(scalar_type)

        with test.subTest(msg=scalar_type.__name__):
            # look up the concrete overload for this scalar type
            overload = triple_kernel_scalar.add_overload([wp.array(dtype=scalar_type), wp.array(dtype=scalar_type)])
            tripled = jax_kernel(overload)

            # per-iteration values are bound as defaults so the jitted function captures this iteration's values
            @jax.jit
            def run(tripled=tripled, jp_dtype=jp_dtype):
                return tripled(jp.arange(count, dtype=jp_dtype))

            with jax.default_device(wp.device_to_jax(device)):
                output = run()

            assert_np_equal(np.asarray(output).reshape((count,)), 3 * np.arange(count, dtype=np_dtype))
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_vecmat(test, device):
    """Call ``triple_kernel_vecmat`` through ``jax_kernel`` for every vector and matrix type.

    For a type ``T`` with ``T._length_`` components, the Jax input is a flat ``arange`` of
    ``(64 // T._length_) * T._length_`` scalars reshaped to ``(64 // T._length_, *T._shape_)``.
    """
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    for vm_type in [*vector_types, *matrix_types]:
        component_type = vm_type._wp_scalar_type_
        jp_dtype = wp.dtype_to_jax(component_type)
        np_dtype = wp.dtype_to_numpy(component_type)

        count = 64 // vm_type._length_
        shape = (count, *vm_type._shape_)
        total = count * vm_type._length_

        with test.subTest(msg=vm_type.__name__):
            # look up the concrete overload for this vector/matrix type
            overload = triple_kernel_vecmat.add_overload([wp.array(dtype=vm_type), wp.array(dtype=vm_type)])
            tripled = jax_kernel(overload)

            @jax.jit
            def run(tripled=tripled, jp_dtype=jp_dtype, total=total, shape=shape):
                return tripled(jp.arange(total, dtype=jp_dtype).reshape(shape))

            with jax.default_device(wp.device_to_jax(device)):
                output = run()

            actual = np.asarray(output).reshape(shape)
            assert_np_equal(actual, 3 * np.arange(total, dtype=np_dtype).reshape(shape))
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_multiarg(test, device):
    """Call ``multiarg_kernel`` (three inputs, two outputs) through ``jax_kernel``.

    With inputs filled with 1, 2 and 3, the outputs must be all 3s (a + b) and all 5s (b + c).
    """
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    count = 64
    pairwise_sums = jax_kernel(multiarg_kernel)

    @jax.jit
    def run():
        a, b, c = (jp.full(count, fill, dtype=jp.float32) for fill in (1, 2, 3))
        return pairwise_sums(a, b, c)

    with jax.default_device(wp.device_to_jax(device)):
        ab, bc = run()

    assert_np_equal(np.asarray(ab), np.full(count, 3, dtype=np.float32))
    assert_np_equal(np.asarray(bc), np.full(count, 5, dtype=np.float32))
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class TestJax(unittest.TestCase):
    """Container for the Jax interop tests; test methods are attached at import time via add_function_test."""
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# try adding Jax tests if Jax is installed correctly.
# Any exception raised while importing/configuring Jax or registering the tests is caught by the
# outer handler and printed, so this module still imports (with no Jax tests attached to
# TestJax) when Jax is missing or broken.
try:
    # prevent Jax from gobbling up GPU memory
    # NOTE(review): these XLA client settings presumably only take effect if set before Jax
    # initializes its GPU backend, so keep them ahead of any device use below.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

    import jax
    import jax.dlpack

    # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
    jax.config.update("jax_enable_x64", True)

    # check which Warp devices work with Jax
    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
    test_devices = get_test_devices()
    jax_compatible_devices = []
    jax_compatible_cuda_devices = []
    for d in test_devices:
        try:
            # smoke test: create a small array on the matching Jax device and run one op on it
            with jax.default_device(wp.device_to_jax(d)):
                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
                j += 1
            jax_compatible_devices.append(d)
            if d.is_cuda:
                jax_compatible_cuda_devices.append(d)
        except Exception as e:
            # a failing device is only excluded; the remaining devices are still probed
            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")

    # dtype conversion tests do not depend on a device
    add_function_test(TestJax, "test_dtype_from_jax", test_dtype_from_jax, devices=None)
    add_function_test(TestJax, "test_dtype_to_jax", test_dtype_to_jax, devices=None)

    if jax_compatible_devices:
        add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)

    # the jax_kernel tests are registered only for CUDA devices that passed the smoke test above
    if jax_compatible_cuda_devices:
        add_function_test(TestJax, "test_jax_kernel_basic", test_jax_kernel_basic, devices=jax_compatible_cuda_devices)
        add_function_test(
            TestJax, "test_jax_kernel_scalar", test_jax_kernel_scalar, devices=jax_compatible_cuda_devices
        )
        add_function_test(
            TestJax, "test_jax_kernel_vecmat", test_jax_kernel_vecmat, devices=jax_compatible_cuda_devices
        )
        add_function_test(
            TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
        )

except Exception as e:
    print(f"Skipping Jax tests due to exception: {e}")
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
if __name__ == "__main__":
    # clear Warp's kernel cache before running (presumably so this run rebuilds the kernels
    # instead of reusing cached binaries), then run the tests verbosely
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2)
|