warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2220 -313
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1497 -226
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +313 -201
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +529 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/tests/test_sparse.py
CHANGED
|
@@ -18,6 +18,7 @@ import unittest
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
|
|
20
20
|
import warp as wp
|
|
21
|
+
from warp._src.sparse import bsr_set_zero
|
|
21
22
|
from warp.sparse import (
|
|
22
23
|
bsr_assign,
|
|
23
24
|
bsr_axpy,
|
|
@@ -59,6 +60,17 @@ def _triplets_to_dense(shape, rows, cols, values):
|
|
|
59
60
|
return mat
|
|
60
61
|
|
|
61
62
|
|
|
63
|
+
def _bsr_pruned(bsr):
|
|
64
|
+
return bsr_from_triplets(
|
|
65
|
+
rows_of_blocks=bsr.nrow,
|
|
66
|
+
cols_of_blocks=bsr.ncol,
|
|
67
|
+
rows=bsr.uncompress_rows(),
|
|
68
|
+
columns=bsr.columns,
|
|
69
|
+
values=bsr.values,
|
|
70
|
+
prune_numerical_zeros=True,
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
|
|
62
74
|
def _bsr_to_dense(bsr):
|
|
63
75
|
mat = np.zeros(bsr.shape)
|
|
64
76
|
|
|
@@ -113,7 +125,7 @@ def test_bsr_from_triplets(test, device):
|
|
|
113
125
|
|
|
114
126
|
ref = _triplets_to_dense(shape, rows, cols, vals)
|
|
115
127
|
|
|
116
|
-
bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
|
|
128
|
+
bsr = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=float), device=device)
|
|
117
129
|
bsr_set_from_triplets(bsr, rows, cols, vals)
|
|
118
130
|
test.assertEqual(bsr.block_size, block_shape[0] * block_shape[1])
|
|
119
131
|
|
|
@@ -218,7 +230,7 @@ def test_bsr_get_set_diag(test, device):
|
|
|
218
230
|
vals_np = rng.random(size=(nnz, block_shape[0], block_shape[1]))
|
|
219
231
|
vals = wp.array(vals_np, dtype=float, device=device)
|
|
220
232
|
|
|
221
|
-
bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
|
|
233
|
+
bsr = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=float), device=device)
|
|
222
234
|
bsr_set_from_triplets(bsr, rows, cols, vals)
|
|
223
235
|
|
|
224
236
|
diag = bsr_get_diag(bsr)
|
|
@@ -274,14 +286,13 @@ def test_bsr_split_merge(test, device):
|
|
|
274
286
|
block_shape = (4, 2)
|
|
275
287
|
nrow = 4
|
|
276
288
|
ncol = 8
|
|
277
|
-
shape = (block_shape[0] * nrow, block_shape[1] * ncol)
|
|
278
289
|
n = 20
|
|
279
290
|
|
|
280
291
|
rows = wp.array(rng.integers(0, high=nrow, size=n, dtype=int), dtype=int, device=device)
|
|
281
292
|
cols = wp.array(rng.integers(0, high=ncol, size=n, dtype=int), dtype=int, device=device)
|
|
282
293
|
vals = wp.array(rng.random(size=(n, block_shape[0], block_shape[1])), dtype=float, device=device)
|
|
283
294
|
|
|
284
|
-
bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
|
|
295
|
+
bsr = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=float), device=device)
|
|
285
296
|
bsr_set_from_triplets(bsr, rows, cols, vals)
|
|
286
297
|
ref = _bsr_to_dense(bsr)
|
|
287
298
|
|
|
@@ -359,13 +370,13 @@ def make_test_bsr_transpose(block_shape, scalar_type):
|
|
|
359
370
|
vals_np = rng.random(size=(nnz, block_shape[0], block_shape[1]))
|
|
360
371
|
vals = wp.array(vals_np, dtype=scalar_type, device=device).reshape((nnz, block_shape[0], block_shape[1]))
|
|
361
372
|
|
|
362
|
-
bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
|
|
373
|
+
bsr = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
|
|
363
374
|
bsr_set_from_triplets(bsr, rows, cols, vals)
|
|
364
375
|
ref = 2.0 * np.transpose(_bsr_to_dense(bsr))
|
|
365
376
|
|
|
366
|
-
bsr_transposed = (2.0 * bsr).transpose()
|
|
377
|
+
bsr_transposed = (2.0 * bsr).transpose().eval()
|
|
367
378
|
|
|
368
|
-
res = _bsr_to_dense(bsr_transposed
|
|
379
|
+
res = _bsr_to_dense(bsr_transposed)
|
|
369
380
|
assert_np_equal(res, ref, 0.0001)
|
|
370
381
|
|
|
371
382
|
if block_shape[0] != block_shape[-1]:
|
|
@@ -373,6 +384,22 @@ def make_test_bsr_transpose(block_shape, scalar_type):
|
|
|
373
384
|
with test.assertRaisesRegex(ValueError, "Destination block shape must be"):
|
|
374
385
|
bsr_set_transpose(dest=bsr, src=bsr)
|
|
375
386
|
|
|
387
|
+
# test masked transpose
|
|
388
|
+
# remove some non zeros from src and dest matrices
|
|
389
|
+
bsr_set_from_triplets(bsr, rows[:3], cols[:3], vals[:3])
|
|
390
|
+
bsr_transposed = bsr_from_triplets(
|
|
391
|
+
bsr_transposed.nrow,
|
|
392
|
+
bsr_transposed.ncol,
|
|
393
|
+
bsr_transposed.uncompress_rows()[:3],
|
|
394
|
+
bsr_transposed.columns[:3],
|
|
395
|
+
bsr_transposed.values[:3],
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
assert_np_equal(bsr_transposed.uncompress_rows().numpy()[:3], [0, 1, 1])
|
|
399
|
+
assert_np_equal(bsr_transposed.columns.numpy()[:3], [2, 0, 2])
|
|
400
|
+
bsr_set_transpose(bsr_transposed, bsr, masked=True)
|
|
401
|
+
assert _bsr_pruned(bsr_transposed).nnz_sync() == 2
|
|
402
|
+
|
|
376
403
|
return test_bsr_transpose
|
|
377
404
|
|
|
378
405
|
|
|
@@ -392,7 +419,7 @@ def make_test_bsr_axpy(block_shape, scalar_type):
|
|
|
392
419
|
x_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
|
|
393
420
|
x_vals = x_vals.reshape((nnz, block_shape[0], block_shape[1]))
|
|
394
421
|
|
|
395
|
-
x = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
|
|
422
|
+
x = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
|
|
396
423
|
bsr_set_from_triplets(x, x_rows, x_cols, x_vals)
|
|
397
424
|
|
|
398
425
|
y_rows = wp.array(rng.integers(0, high=nrow, size=nnz, dtype=int), dtype=int, device=device)
|
|
@@ -400,7 +427,7 @@ def make_test_bsr_axpy(block_shape, scalar_type):
|
|
|
400
427
|
y_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
|
|
401
428
|
y_vals = y_vals.reshape((nnz, block_shape[0], block_shape[1]))
|
|
402
429
|
|
|
403
|
-
y = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
|
|
430
|
+
y = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
|
|
404
431
|
bsr_set_from_triplets(y, y_rows, y_cols, y_vals)
|
|
405
432
|
|
|
406
433
|
work_arrays = bsr_axpy_work_arrays()
|
|
@@ -457,7 +484,7 @@ def make_test_bsr_mm(block_shape, scalar_type):
|
|
|
457
484
|
x_vals = wp.array(rng.random(size=(nnz, x_block_shape[0], x_block_shape[1])), dtype=scalar_type, device=device)
|
|
458
485
|
x_vals = x_vals.reshape((nnz, x_block_shape[0], x_block_shape[1]))
|
|
459
486
|
|
|
460
|
-
x = bsr_zeros(x_nrow, x_ncol, wp.types.matrix(shape=x_block_shape, dtype=scalar_type), device=device)
|
|
487
|
+
x = bsr_zeros(x_nrow, x_ncol, wp._src.types.matrix(shape=x_block_shape, dtype=scalar_type), device=device)
|
|
461
488
|
bsr_set_from_triplets(x, x_rows, x_cols, x_vals)
|
|
462
489
|
|
|
463
490
|
y_rows = wp.array(rng.integers(0, high=y_nrow, size=nnz, dtype=int), dtype=int, device=device)
|
|
@@ -465,7 +492,7 @@ def make_test_bsr_mm(block_shape, scalar_type):
|
|
|
465
492
|
y_vals = wp.array(rng.random(size=(nnz, y_block_shape[0], y_block_shape[1])), dtype=scalar_type, device=device)
|
|
466
493
|
y_vals = y_vals.reshape((nnz, y_block_shape[0], y_block_shape[1]))
|
|
467
494
|
|
|
468
|
-
y = bsr_zeros(y_nrow, y_ncol, wp.types.matrix(shape=y_block_shape, dtype=scalar_type), device=device)
|
|
495
|
+
y = bsr_zeros(y_nrow, y_ncol, wp._src.types.matrix(shape=y_block_shape, dtype=scalar_type), device=device)
|
|
469
496
|
bsr_set_from_triplets(y, y_rows, y_cols, y_vals)
|
|
470
497
|
|
|
471
498
|
z_rows = wp.array(rng.integers(0, high=z_nrow, size=nnz, dtype=int), dtype=int, device=device)
|
|
@@ -473,7 +500,7 @@ def make_test_bsr_mm(block_shape, scalar_type):
|
|
|
473
500
|
z_vals = wp.array(rng.random(size=(nnz, z_block_shape[0], z_block_shape[1])), dtype=scalar_type, device=device)
|
|
474
501
|
z_vals = z_vals.reshape((nnz, z_block_shape[0], z_block_shape[1]))
|
|
475
502
|
|
|
476
|
-
z = bsr_zeros(z_nrow, z_ncol, wp.types.matrix(shape=z_block_shape, dtype=scalar_type), device=device)
|
|
503
|
+
z = bsr_zeros(z_nrow, z_ncol, wp._src.types.matrix(shape=z_block_shape, dtype=scalar_type), device=device)
|
|
477
504
|
bsr_set_from_triplets(z, z_rows, z_cols, z_vals)
|
|
478
505
|
|
|
479
506
|
work_arrays = bsr_mm_work_arrays()
|
|
@@ -544,7 +571,7 @@ def make_test_bsr_mv(block_shape, scalar_type):
|
|
|
544
571
|
A_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
|
|
545
572
|
A_vals = A_vals.reshape((nnz, block_shape[0], block_shape[1]))
|
|
546
573
|
|
|
547
|
-
A = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
|
|
574
|
+
A = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
|
|
548
575
|
bsr_set_from_triplets(A, A_rows, A_cols, A_vals)
|
|
549
576
|
|
|
550
577
|
if block_shape[1] == 1:
|
|
@@ -664,6 +691,83 @@ def make_test_bsr_multiply_deep(block_shape, scalar_type):
|
|
|
664
691
|
return test_bsr_multiply_deep
|
|
665
692
|
|
|
666
693
|
|
|
694
|
+
def test_bsr_mm_max_new_nnz(test, device):
|
|
695
|
+
"""Test that BSR matrix multiplication with max_new_nnz works"""
|
|
696
|
+
A = bsr_from_triplets(
|
|
697
|
+
2,
|
|
698
|
+
2,
|
|
699
|
+
wp.array([0, 0, 1, 1], dtype=int, device=device),
|
|
700
|
+
wp.array([0, 1, 0, 1], dtype=int, device=device),
|
|
701
|
+
wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, device=device),
|
|
702
|
+
)
|
|
703
|
+
B = bsr_from_triplets(
|
|
704
|
+
2,
|
|
705
|
+
2,
|
|
706
|
+
wp.array([0, 0, 1, 1], dtype=int, device=device),
|
|
707
|
+
wp.array([0, 1, 0, 1], dtype=int, device=device),
|
|
708
|
+
wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, device=device),
|
|
709
|
+
)
|
|
710
|
+
C = bsr_zeros(2, 2, wp.float32, device=device)
|
|
711
|
+
|
|
712
|
+
# max_new_nnz big enough
|
|
713
|
+
bsr_mm(A, B, C, max_new_nnz=4)
|
|
714
|
+
test.assertEqual(C.nnz_sync(), 4)
|
|
715
|
+
|
|
716
|
+
bsr_set_zero(C)
|
|
717
|
+
test.assertEqual(C.nnz_sync(), 0)
|
|
718
|
+
|
|
719
|
+
# max_new_nnz too small, check warning
|
|
720
|
+
capture = StdOutCapture()
|
|
721
|
+
capture.begin()
|
|
722
|
+
bsr_mm(A, B, C, max_new_nnz=2)
|
|
723
|
+
test.assertEqual(C.nnz_sync(), 2)
|
|
724
|
+
output = capture.end()
|
|
725
|
+
|
|
726
|
+
# Check that the output contains warnings about "max_new_nnz" being exceeded.
|
|
727
|
+
# Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed.
|
|
728
|
+
if output != "" or sys.platform != "win32":
|
|
729
|
+
test.assertRegex(output, r"exceeded")
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def test_capturability(test, device):
|
|
733
|
+
"""Test that BSR operations are graph-capturable"""
|
|
734
|
+
|
|
735
|
+
N = 5
|
|
736
|
+
M = 3
|
|
737
|
+
|
|
738
|
+
C = bsr_diag(wp.zeros(N, dtype=wp.mat33, device=device))
|
|
739
|
+
|
|
740
|
+
rows = wp.array([3, 4, 2, 0, 1], dtype=int, device=device)
|
|
741
|
+
columns = wp.array([2, 0, 1, 2, 1], dtype=int, device=device)
|
|
742
|
+
values = wp.ones(5, dtype=wp.mat33, device=device)
|
|
743
|
+
|
|
744
|
+
def test_body():
|
|
745
|
+
A = bsr_from_triplets(
|
|
746
|
+
N,
|
|
747
|
+
M,
|
|
748
|
+
rows=rows,
|
|
749
|
+
columns=columns,
|
|
750
|
+
values=values,
|
|
751
|
+
)
|
|
752
|
+
B = A + bsr_copy(A * 2.0)
|
|
753
|
+
bsr_mm(A, bsr_transposed(B), C, max_new_nnz=N * N)
|
|
754
|
+
|
|
755
|
+
# ensure necessary modules are loaded and reset result
|
|
756
|
+
test_body()
|
|
757
|
+
bsr_set_zero(C)
|
|
758
|
+
test.assertEqual(C.nnz_sync(), 0)
|
|
759
|
+
|
|
760
|
+
with wp.ScopedDevice(device):
|
|
761
|
+
with wp.ScopedCapture(force_module_load=False) as capture:
|
|
762
|
+
test_body()
|
|
763
|
+
|
|
764
|
+
assert_array_equal(bsr_get_diag(C), wp.zeros(N, dtype=wp.mat33, device=device))
|
|
765
|
+
|
|
766
|
+
wp.capture_launch(capture.graph)
|
|
767
|
+
test.assertEqual(C.nnz_sync(), 9)
|
|
768
|
+
assert_array_equal(bsr_get_diag(C), wp.full(N, value=wp.mat33(9.0), dtype=wp.mat33, device=device))
|
|
769
|
+
|
|
770
|
+
|
|
667
771
|
devices = get_test_devices()
|
|
668
772
|
cuda_test_devices = get_selected_cuda_test_devices()
|
|
669
773
|
|
|
@@ -676,7 +780,9 @@ class TestSparse(unittest.TestCase):
|
|
|
676
780
|
diag_bsr = bsr_diag(diag=np.eye(bsize, dtype=float) * 2.0, rows_of_blocks=nrow)
|
|
677
781
|
diag_copy = bsr_copy(diag_bsr, scalar_type=wp.float64)
|
|
678
782
|
|
|
679
|
-
self.assertTrue(
|
|
783
|
+
self.assertTrue(
|
|
784
|
+
wp._src.types.types_equal(diag_copy.values.dtype, wp.mat(shape=(bsize, bsize), dtype=wp.float64))
|
|
785
|
+
)
|
|
680
786
|
bsr_scale(x=diag_copy, alpha=0.5)
|
|
681
787
|
|
|
682
788
|
res = _bsr_to_dense(diag_copy)
|
|
@@ -686,7 +792,10 @@ class TestSparse(unittest.TestCase):
|
|
|
686
792
|
bsr_scale(x=diag_copy, alpha=0.0)
|
|
687
793
|
self.assertEqual(diag_copy.nrow, nrow)
|
|
688
794
|
self.assertEqual(diag_copy.ncol, nrow)
|
|
689
|
-
self.assertEqual(diag_copy.nnz,
|
|
795
|
+
self.assertEqual(diag_copy.nnz, diag_bsr.nnz)
|
|
796
|
+
|
|
797
|
+
diag_pruned = _bsr_pruned(diag_copy)
|
|
798
|
+
self.assertEqual(diag_pruned.nnz_sync(), 0)
|
|
690
799
|
|
|
691
800
|
|
|
692
801
|
add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
|
|
@@ -728,6 +837,8 @@ add_function_test(TestSparse, "test_csr_mv", make_test_bsr_mv((1, 1), wp.float32
|
|
|
728
837
|
add_function_test(TestSparse, "test_bsr_mv_1_3", make_test_bsr_mv((1, 3), wp.float32), devices=devices)
|
|
729
838
|
add_function_test(TestSparse, "test_bsr_mv_3_3", make_test_bsr_mv((3, 3), wp.float64), devices=devices)
|
|
730
839
|
|
|
840
|
+
add_function_test(TestSparse, "test_capturability", test_capturability, devices=cuda_test_devices)
|
|
841
|
+
add_function_test(TestSparse, "test_bsr_mm_max_new_nnz", test_bsr_mm_max_new_nnz, devices=devices, check_output=False)
|
|
731
842
|
|
|
732
843
|
if __name__ == "__main__":
|
|
733
844
|
wp.clear_kernel_cache()
|