warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/tests/test_rounding.py
CHANGED
|
@@ -47,41 +47,13 @@ def test_kernel(
|
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def test_rounding(test, device):
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
3.1,
|
|
58
|
-
2.9,
|
|
59
|
-
2.5,
|
|
60
|
-
2.1,
|
|
61
|
-
1.9,
|
|
62
|
-
1.5,
|
|
63
|
-
1.1,
|
|
64
|
-
0.9,
|
|
65
|
-
0.5,
|
|
66
|
-
0.1,
|
|
67
|
-
-0.1,
|
|
68
|
-
-0.5,
|
|
69
|
-
-0.9,
|
|
70
|
-
-1.1,
|
|
71
|
-
-1.5,
|
|
72
|
-
-1.9,
|
|
73
|
-
-2.1,
|
|
74
|
-
-2.5,
|
|
75
|
-
-2.9,
|
|
76
|
-
-3.1,
|
|
77
|
-
-3.5,
|
|
78
|
-
-3.9,
|
|
79
|
-
-4.1,
|
|
80
|
-
-4.5,
|
|
81
|
-
-4.9,
|
|
82
|
-
],
|
|
83
|
-
dtype=np.float32,
|
|
84
|
-
)
|
|
50
|
+
# fmt: off
|
|
51
|
+
nx = np.array([
|
|
52
|
+
4.9, 4.5, 4.1, 3.9, 3.5, 3.1, 2.9, 2.5, 2.1, 1.9,
|
|
53
|
+
1.5, 1.1, 0.9, 0.5, 0.1, -0.1, -0.5, -0.9, -1.1, -1.5,
|
|
54
|
+
-1.9, -2.1, -2.5, -2.9, -3.1, -3.5, -3.9, -4.1, -4.5, -4.9
|
|
55
|
+
], dtype=np.float32)
|
|
56
|
+
# fmt: on
|
|
85
57
|
|
|
86
58
|
x = wp.array(nx, device=device)
|
|
87
59
|
N = len(x)
|
|
@@ -149,10 +121,10 @@ def test_rounding(test, device):
|
|
|
149
121
|
assert_np_equal(tab, golden, tol=1e-6)
|
|
150
122
|
|
|
151
123
|
if print_results:
|
|
152
|
-
np.set_printoptions(formatter={"float": lambda x: "{:6.1f}".
|
|
124
|
+
np.set_printoptions(formatter={"float": lambda x: f"{x:6.1f}".replace(".0", ".")})
|
|
153
125
|
|
|
154
126
|
print("----------------------------------------------")
|
|
155
|
-
print("
|
|
127
|
+
print(f" {'x ':>5s} {'round':>5s} {'rint':>5s} {'trunc':>5s} {'cast':>5s} {'floor':>5s} {'ceil':>5s}")
|
|
156
128
|
print(tab)
|
|
157
129
|
print("----------------------------------------------")
|
|
158
130
|
|
|
@@ -166,7 +138,7 @@ def test_rounding(test, device):
|
|
|
166
138
|
nx_frac = np.modf(nx)[0]
|
|
167
139
|
|
|
168
140
|
tab = np.stack([nx, nx_round, nx_rint, nx_trunc, nx_fix, nx_floor, nx_ceil, nx_frac], axis=1)
|
|
169
|
-
print("
|
|
141
|
+
print(f" {'x ':>5s} {'round':>5s} {'rint':>5s} {'trunc':>5s} {'fix':>5s} {'floor':>5s} {'ceil':>5s}")
|
|
170
142
|
print(tab)
|
|
171
143
|
print("----------------------------------------------")
|
|
172
144
|
|
|
@@ -48,7 +48,7 @@ def test_runlength_encode_error_insufficient_storage(test, device):
|
|
|
48
48
|
run_lengths = wp.empty(123, dtype=int, device=device)
|
|
49
49
|
with test.assertRaisesRegex(
|
|
50
50
|
RuntimeError,
|
|
51
|
-
r"Output array storage sizes must be at least equal to value_count$",
|
|
51
|
+
r"Output array storage sizes must be at least equal to value_count \(123\)$",
|
|
52
52
|
):
|
|
53
53
|
runlength_encode(values, run_values, run_lengths)
|
|
54
54
|
|
|
@@ -57,7 +57,7 @@ def test_runlength_encode_error_insufficient_storage(test, device):
|
|
|
57
57
|
run_lengths = wp.empty(1, dtype=int, device="cpu")
|
|
58
58
|
with test.assertRaisesRegex(
|
|
59
59
|
RuntimeError,
|
|
60
|
-
r"Output array storage sizes must be at least equal to value_count$",
|
|
60
|
+
r"Output array storage sizes must be at least equal to value_count \(123\)$",
|
|
61
61
|
):
|
|
62
62
|
runlength_encode(values, run_values, run_lengths)
|
|
63
63
|
|
|
@@ -68,7 +68,7 @@ def test_runlength_encode_error_dtypes_mismatch(test, device):
|
|
|
68
68
|
run_lengths = wp.empty_like(values, device=device)
|
|
69
69
|
with test.assertRaisesRegex(
|
|
70
70
|
RuntimeError,
|
|
71
|
-
r"values and run_values data types do not match$",
|
|
71
|
+
r"values and run_values data types do not match \(int32 vs float32\)$",
|
|
72
72
|
):
|
|
73
73
|
runlength_encode(values, run_values, run_lengths)
|
|
74
74
|
|
|
@@ -102,7 +102,7 @@ def test_runlength_encode_error_unsupported_dtype(test, device):
|
|
|
102
102
|
run_lengths = wp.empty(123, dtype=int, device=device)
|
|
103
103
|
with test.assertRaisesRegex(
|
|
104
104
|
RuntimeError,
|
|
105
|
-
r"Unsupported data type$",
|
|
105
|
+
r"Unsupported data type: float32$",
|
|
106
106
|
):
|
|
107
107
|
runlength_encode(values, run_values, run_lengths)
|
|
108
108
|
|
|
@@ -118,7 +118,7 @@ class TestRunlengthEncode(unittest.TestCase):
|
|
|
118
118
|
run_lengths = wp.empty_like(values, device="cuda:0")
|
|
119
119
|
with self.assertRaisesRegex(
|
|
120
120
|
RuntimeError,
|
|
121
|
-
r"
|
|
121
|
+
r"run_values, run_lengths and values storage devices do not match$",
|
|
122
122
|
):
|
|
123
123
|
runlength_encode(values, run_values, run_lengths)
|
|
124
124
|
|
|
@@ -127,7 +127,7 @@ class TestRunlengthEncode(unittest.TestCase):
|
|
|
127
127
|
run_lengths = wp.empty_like(values, device="cuda:0")
|
|
128
128
|
with self.assertRaisesRegex(
|
|
129
129
|
RuntimeError,
|
|
130
|
-
r"
|
|
130
|
+
r"run_values, run_lengths and values storage devices do not match$",
|
|
131
131
|
):
|
|
132
132
|
runlength_encode(values, run_values, run_lengths)
|
|
133
133
|
|
|
@@ -136,7 +136,7 @@ class TestRunlengthEncode(unittest.TestCase):
|
|
|
136
136
|
run_lengths = wp.empty_like(values, device="cpu")
|
|
137
137
|
with self.assertRaisesRegex(
|
|
138
138
|
RuntimeError,
|
|
139
|
-
r"
|
|
139
|
+
r"run_values, run_lengths and values storage devices do not match$",
|
|
140
140
|
):
|
|
141
141
|
runlength_encode(values, run_values, run_lengths)
|
|
142
142
|
|
warp/tests/test_smoothstep.py
CHANGED
|
@@ -78,7 +78,7 @@ def test_smoothstep(test, device):
|
|
|
78
78
|
a = wp.array([test_data.a], dtype=data_type, device=device, requires_grad=True)
|
|
79
79
|
b = wp.array([test_data.b], dtype=data_type, device=device, requires_grad=True)
|
|
80
80
|
t = wp.array([test_data.t], dtype=float, device=device, requires_grad=True)
|
|
81
|
-
out = wp.array([0] * wp.types.
|
|
81
|
+
out = wp.array([0] * wp.types.type_size(data_type), dtype=data_type, device=device, requires_grad=True)
|
|
82
82
|
|
|
83
83
|
with wp.Tape() as tape:
|
|
84
84
|
wp.launch(kernel, dim=1, inputs=[a, b, t, out], device=device)
|
warp/tests/test_sparse.py
CHANGED
|
@@ -130,6 +130,80 @@ def test_bsr_from_triplets(test, device):
|
|
|
130
130
|
)
|
|
131
131
|
test.assertEqual(bsr.nnz, 0)
|
|
132
132
|
|
|
133
|
+
# test passing indices with wrong data ty[e]
|
|
134
|
+
rows = wp.array(rows.numpy().astype(float), dtype=float, device=device)
|
|
135
|
+
cols = wp.array(cols.numpy().astype(float), dtype=float, device=device)
|
|
136
|
+
with test.assertRaisesRegex(
|
|
137
|
+
TypeError,
|
|
138
|
+
r"Rows and columns arrays must be of type int32$",
|
|
139
|
+
):
|
|
140
|
+
bsr_set_from_triplets(bsr, rows, cols, vals)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def test_bsr_from_triplets_prune_numerical_zeros(test, device):
|
|
144
|
+
rows = wp.array([1, 0, 2, 3], dtype=int)
|
|
145
|
+
cols = wp.array([0, 1, 2, 3], dtype=int)
|
|
146
|
+
vals = wp.zeros(len(rows), dtype=float)
|
|
147
|
+
|
|
148
|
+
A = bsr_from_triplets(
|
|
149
|
+
rows_of_blocks=12, # Number of rows of blocks
|
|
150
|
+
cols_of_blocks=12, # Number of columns of blocks
|
|
151
|
+
rows=rows, # Row indices
|
|
152
|
+
columns=cols, # Column indices
|
|
153
|
+
values=vals, # Block values
|
|
154
|
+
prune_numerical_zeros=False,
|
|
155
|
+
)
|
|
156
|
+
assert A.nnz_sync() == 4
|
|
157
|
+
|
|
158
|
+
A = bsr_from_triplets(
|
|
159
|
+
rows_of_blocks=12, # Number of rows of blocks
|
|
160
|
+
cols_of_blocks=12, # Number of columns of blocks
|
|
161
|
+
rows=rows, # Row indices
|
|
162
|
+
columns=cols, # Column indices
|
|
163
|
+
values=vals, # Block values
|
|
164
|
+
prune_numerical_zeros=True,
|
|
165
|
+
)
|
|
166
|
+
assert A.nnz_sync() == 0
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def test_bsr_from_triplets_gradient(test, device):
|
|
170
|
+
rng = np.random.default_rng(123)
|
|
171
|
+
|
|
172
|
+
block_shape = (3, 3)
|
|
173
|
+
nrow = 2
|
|
174
|
+
ncol = 2
|
|
175
|
+
|
|
176
|
+
n = 4
|
|
177
|
+
rows = wp.array([1, 0, 0, 1], dtype=int, device=device)
|
|
178
|
+
cols = wp.array([0, 1, 0, 0], dtype=int, device=device)
|
|
179
|
+
|
|
180
|
+
vals = wp.array(
|
|
181
|
+
rng.random(size=(n, block_shape[0], block_shape[1])), dtype=wp.mat33, device=device, requires_grad=True
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
with wp.Tape() as tape:
|
|
185
|
+
mat = bsr_from_triplets(nrow, ncol, rows, cols, vals)
|
|
186
|
+
|
|
187
|
+
assert mat.nnz_sync() == 3
|
|
188
|
+
|
|
189
|
+
zero_block = np.zeros((3, 3))
|
|
190
|
+
ones_block = np.ones((3, 3))
|
|
191
|
+
|
|
192
|
+
mat.values.grad[0:1].fill_(1.0)
|
|
193
|
+
tape.backward()
|
|
194
|
+
assert_np_equal(vals.grad.numpy(), np.stack((zero_block, zero_block, ones_block, zero_block)))
|
|
195
|
+
tape.zero()
|
|
196
|
+
|
|
197
|
+
mat.values.grad[1:2].fill_(1.0)
|
|
198
|
+
tape.backward()
|
|
199
|
+
assert_np_equal(vals.grad.numpy(), np.stack((zero_block, ones_block, zero_block, zero_block)))
|
|
200
|
+
tape.zero()
|
|
201
|
+
|
|
202
|
+
mat.values.grad[2:3].fill_(1.0)
|
|
203
|
+
tape.backward()
|
|
204
|
+
assert_np_equal(vals.grad.numpy(), np.stack((ones_block, zero_block, zero_block, ones_block)))
|
|
205
|
+
tape.zero()
|
|
206
|
+
|
|
133
207
|
|
|
134
208
|
def test_bsr_get_set_diag(test, device):
|
|
135
209
|
rng = np.random.default_rng(123)
|
|
@@ -191,7 +265,7 @@ def test_bsr_get_set_diag(test, device):
|
|
|
191
265
|
assert_np_equal(diag_bsr.values.numpy(), np.broadcast_to(np.eye(4), shape=(nrow, 4, 4)), tol=0.000001)
|
|
192
266
|
|
|
193
267
|
diag_csr = bsr_identity(nrow, block_type=wp.float64, device=device)
|
|
194
|
-
|
|
268
|
+
np.testing.assert_array_equal(diag_csr.values.numpy(), np.ones(nrow, dtype=float))
|
|
195
269
|
|
|
196
270
|
|
|
197
271
|
def test_bsr_split_merge(test, device):
|
|
@@ -239,7 +313,7 @@ def test_bsr_split_merge(test, device):
|
|
|
239
313
|
|
|
240
314
|
with test.assertRaisesRegex(
|
|
241
315
|
ValueError,
|
|
242
|
-
"The requested block shape does not evenly divide the source matrix",
|
|
316
|
+
r"The requested block shape \(32, 32\) does not evenly divide the source matrix of total size \(16, 16\)",
|
|
243
317
|
):
|
|
244
318
|
bsr_copy(bsr, block_shape=(32, 32))
|
|
245
319
|
|
|
@@ -556,9 +630,16 @@ class TestSparse(unittest.TestCase):
|
|
|
556
630
|
|
|
557
631
|
add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
|
|
558
632
|
add_function_test(TestSparse, "test_bsr_from_triplets", test_bsr_from_triplets, devices=devices)
|
|
633
|
+
add_function_test(
|
|
634
|
+
TestSparse,
|
|
635
|
+
"test_bsr_from_triplets_prune_numerical_zeros",
|
|
636
|
+
test_bsr_from_triplets_prune_numerical_zeros,
|
|
637
|
+
devices=devices,
|
|
638
|
+
)
|
|
559
639
|
add_function_test(TestSparse, "test_bsr_get_diag", test_bsr_get_set_diag, devices=devices)
|
|
560
640
|
add_function_test(TestSparse, "test_bsr_split_merge", test_bsr_split_merge, devices=devices)
|
|
561
641
|
add_function_test(TestSparse, "test_bsr_assign_masked", test_bsr_assign_masked, devices=devices)
|
|
642
|
+
add_function_test(TestSparse, "test_bsr_from_triplets_gradient", test_bsr_from_triplets_gradient, devices=devices)
|
|
562
643
|
|
|
563
644
|
add_function_test(TestSparse, "test_csr_transpose", make_test_bsr_transpose((1, 1), wp.float32), devices=devices)
|
|
564
645
|
add_function_test(TestSparse, "test_bsr_transpose_1_3", make_test_bsr_transpose((1, 3), wp.float32), devices=devices)
|