warp-lang 1.3.3__py3-none-macosx_10_13_universal2.whl → 1.4.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +6 -0
- warp/autograd.py +59 -6
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build_dll.py +8 -10
- warp/builtins.py +126 -4
- warp/codegen.py +435 -53
- warp/config.py +1 -1
- warp/context.py +678 -403
- warp/dlpack.py +2 -0
- warp/examples/benchmarks/benchmark_cloth.py +10 -0
- warp/examples/core/example_render_opengl.py +12 -10
- warp/examples/fem/example_adaptive_grid.py +251 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +2 -2
- warp/examples/fem/example_magnetostatics.py +1 -1
- warp/examples/fem/example_streamlines.py +1 -0
- warp/examples/fem/utils.py +23 -4
- warp/examples/sim/example_cloth.py +50 -6
- warp/fem/__init__.py +2 -0
- warp/fem/adaptivity.py +493 -0
- warp/fem/field/field.py +2 -1
- warp/fem/field/nodal_field.py +18 -26
- warp/fem/field/test.py +4 -4
- warp/fem/field/trial.py +4 -4
- warp/fem/geometry/__init__.py +1 -0
- warp/fem/geometry/adaptive_nanogrid.py +843 -0
- warp/fem/geometry/nanogrid.py +55 -28
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/nanogrid_function_space.py +69 -35
- warp/fem/utils.py +113 -107
- warp/jax_experimental.py +28 -15
- warp/native/array.h +0 -1
- warp/native/builtin.h +103 -6
- warp/native/bvh.cu +2 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/error.cpp +4 -2
- warp/native/exports.h +99 -17
- warp/native/mat.h +97 -0
- warp/native/mesh.cpp +36 -0
- warp/native/mesh.cu +51 -0
- warp/native/mesh.h +1 -0
- warp/native/quat.h +43 -0
- warp/native/spatial.h +6 -0
- warp/native/vec.h +74 -0
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +10 -3
- warp/native/warp.h +8 -1
- warp/paddle.py +382 -0
- warp/sim/__init__.py +1 -0
- warp/sim/collide.py +519 -0
- warp/sim/integrator_euler.py +18 -5
- warp/sim/integrator_featherstone.py +5 -5
- warp/sim/integrator_vbd.py +1026 -0
- warp/sim/model.py +49 -23
- warp/stubs.py +459 -0
- warp/tape.py +2 -0
- warp/tests/aux_test_dependent.py +1 -0
- warp/tests/aux_test_name_clash1.py +32 -0
- warp/tests/aux_test_name_clash2.py +32 -0
- warp/tests/aux_test_square.py +1 -0
- warp/tests/test_array.py +188 -0
- warp/tests/test_async.py +3 -3
- warp/tests/test_atomic.py +6 -0
- warp/tests/test_closest_point_edge_edge.py +93 -1
- warp/tests/test_codegen.py +62 -15
- warp/tests/test_codegen_instancing.py +1457 -0
- warp/tests/test_collision.py +486 -0
- warp/tests/test_compile_consts.py +3 -28
- warp/tests/test_dlpack.py +170 -0
- warp/tests/test_examples.py +22 -8
- warp/tests/test_fast_math.py +10 -4
- warp/tests/test_fem.py +64 -0
- warp/tests/test_func.py +46 -0
- warp/tests/test_implicit_init.py +49 -0
- warp/tests/test_jax.py +58 -0
- warp/tests/test_mat.py +84 -0
- warp/tests/test_mesh_query_point.py +188 -0
- warp/tests/test_module_hashing.py +40 -0
- warp/tests/test_multigpu.py +3 -3
- warp/tests/test_overwrite.py +8 -0
- warp/tests/test_paddle.py +852 -0
- warp/tests/test_print.py +89 -0
- warp/tests/test_quat.py +111 -0
- warp/tests/test_reload.py +31 -1
- warp/tests/test_scalar_ops.py +2 -0
- warp/tests/test_static.py +412 -0
- warp/tests/test_streams.py +64 -3
- warp/tests/test_struct.py +4 -4
- warp/tests/test_torch.py +24 -0
- warp/tests/test_triangle_closest_point.py +137 -0
- warp/tests/test_types.py +1 -1
- warp/tests/test_vbd.py +386 -0
- warp/tests/test_vec.py +143 -0
- warp/tests/test_vec_scalar_ops.py +139 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +9 -5
- warp/thirdparty/dlpack.py +3 -1
- warp/types.py +150 -28
- warp/utils.py +37 -14
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/METADATA +10 -8
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/RECORD +106 -94
- warp/tests/test_point_triangle_closest_point.py +0 -143
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -11,6 +11,7 @@ import unittest
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
|
|
13
13
|
import warp as wp
|
|
14
|
+
import warp.examples
|
|
14
15
|
from warp.tests.unittest_utils import *
|
|
15
16
|
|
|
16
17
|
|
|
@@ -654,6 +655,192 @@ def test_mesh_query_furthest_point(test, device):
|
|
|
654
655
|
assert_np_equal(dist_query.numpy(), dist_brute.numpy(), tol=1.0e-3)
|
|
655
656
|
|
|
656
657
|
|
|
658
|
+
@wp.func
|
|
659
|
+
def triangle_closest_point_for_test(a: wp.vec3, b: wp.vec3, c: wp.vec3, p: wp.vec3):
|
|
660
|
+
ab = b - a
|
|
661
|
+
ac = c - a
|
|
662
|
+
ap = p - a
|
|
663
|
+
|
|
664
|
+
d1 = wp.dot(ab, ap)
|
|
665
|
+
d2 = wp.dot(ac, ap)
|
|
666
|
+
if d1 <= 0.0 and d2 <= 0.0:
|
|
667
|
+
bary = wp.vec3(1.0, 0.0, 0.0)
|
|
668
|
+
return a, bary
|
|
669
|
+
|
|
670
|
+
bp = p - b
|
|
671
|
+
d3 = wp.dot(ab, bp)
|
|
672
|
+
d4 = wp.dot(ac, bp)
|
|
673
|
+
if d3 >= 0.0 and d4 <= d3:
|
|
674
|
+
bary = wp.vec3(0.0, 1.0, 0.0)
|
|
675
|
+
return b, bary
|
|
676
|
+
|
|
677
|
+
cp = p - c
|
|
678
|
+
d5 = wp.dot(ab, cp)
|
|
679
|
+
d6 = wp.dot(ac, cp)
|
|
680
|
+
if d6 >= 0.0 and d5 <= d6:
|
|
681
|
+
bary = wp.vec3(0.0, 0.0, 1.0)
|
|
682
|
+
return c, bary
|
|
683
|
+
|
|
684
|
+
vc = d1 * d4 - d3 * d2
|
|
685
|
+
if vc <= 0.0 and d1 >= 0.0 and d3 <= 0.0:
|
|
686
|
+
v = d1 / (d1 - d3)
|
|
687
|
+
bary = wp.vec3(1.0 - v, v, 0.0)
|
|
688
|
+
return a + v * ab, bary
|
|
689
|
+
|
|
690
|
+
vb = d5 * d2 - d1 * d6
|
|
691
|
+
if vb <= 0.0 and d2 >= 0.0 and d6 <= 0.0:
|
|
692
|
+
v = d2 / (d2 - d6)
|
|
693
|
+
bary = wp.vec3(1.0 - v, 0.0, v)
|
|
694
|
+
return a + v * ac, bary
|
|
695
|
+
|
|
696
|
+
va = d3 * d6 - d5 * d4
|
|
697
|
+
if va <= 0.0 and (d4 - d3) >= 0.0 and (d5 - d6) >= 0.0:
|
|
698
|
+
v = (d4 - d3) / ((d4 - d3) + (d5 - d6))
|
|
699
|
+
bary = wp.vec3(0.0, 1.0 - v, v)
|
|
700
|
+
return b + v * (c - b), bary
|
|
701
|
+
|
|
702
|
+
denom = 1.0 / (va + vb + vc)
|
|
703
|
+
v = vb * denom
|
|
704
|
+
w = vc * denom
|
|
705
|
+
bary = wp.vec3(1.0 - v - w, v, w)
|
|
706
|
+
return a + v * ab + w * ac, bary
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
def load_mesh():
|
|
710
|
+
from pxr import Usd, UsdGeom
|
|
711
|
+
|
|
712
|
+
usd_stage = Usd.Stage.Open(os.path.join(wp.examples.get_asset_directory(), "bunny.usd"))
|
|
713
|
+
usd_geom = UsdGeom.Mesh(usd_stage.GetPrimAtPath("/root/bunny"))
|
|
714
|
+
|
|
715
|
+
vertices = np.array(usd_geom.GetPointsAttr().Get())
|
|
716
|
+
faces = np.array(usd_geom.GetFaceVertexIndicesAttr().Get())
|
|
717
|
+
|
|
718
|
+
return vertices, faces
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
@wp.kernel
|
|
722
|
+
def point_query_aabb_and_closest(
|
|
723
|
+
query_radius: float,
|
|
724
|
+
mesh_id: wp.uint64,
|
|
725
|
+
pts: wp.array(dtype=wp.vec3),
|
|
726
|
+
pos: wp.array(dtype=wp.vec3),
|
|
727
|
+
tri_indices: wp.array(dtype=wp.int32, ndim=2),
|
|
728
|
+
query_results_num_collisions: wp.array(dtype=wp.int32),
|
|
729
|
+
query_results_min_dist: wp.array(dtype=float),
|
|
730
|
+
query_results_closest_point_velocity: wp.array(dtype=wp.vec3),
|
|
731
|
+
):
|
|
732
|
+
p_index = wp.tid()
|
|
733
|
+
p = pts[p_index]
|
|
734
|
+
|
|
735
|
+
lower = wp.vec3(p[0] - query_radius, p[1] - query_radius, p[2] - query_radius)
|
|
736
|
+
upper = wp.vec3(p[0] + query_radius, p[1] + query_radius, p[2] + query_radius)
|
|
737
|
+
|
|
738
|
+
closest_query = wp.mesh_query_point_no_sign(mesh_id, p, query_radius)
|
|
739
|
+
if closest_query.result:
|
|
740
|
+
closest_p = wp.mesh_eval_position(mesh_id, closest_query.face, closest_query.u, closest_query.v)
|
|
741
|
+
closest_p_vel = wp.mesh_eval_velocity(mesh_id, closest_query.face, closest_query.u, closest_query.v)
|
|
742
|
+
|
|
743
|
+
query_results_min_dist[p_index] = wp.length(closest_p - p)
|
|
744
|
+
query_results_closest_point_velocity[p_index] = closest_p_vel
|
|
745
|
+
|
|
746
|
+
query = wp.mesh_query_aabb(mesh_id, lower, upper)
|
|
747
|
+
|
|
748
|
+
tri_index = wp.int32(0)
|
|
749
|
+
num_collisions = wp.int32(0)
|
|
750
|
+
min_dis_to_tris = query_radius
|
|
751
|
+
while wp.mesh_query_aabb_next(query, tri_index):
|
|
752
|
+
t1 = tri_indices[tri_index, 0]
|
|
753
|
+
t2 = tri_indices[tri_index, 1]
|
|
754
|
+
t3 = tri_indices[tri_index, 2]
|
|
755
|
+
|
|
756
|
+
u1 = pos[t1]
|
|
757
|
+
u2 = pos[t2]
|
|
758
|
+
u3 = pos[t3]
|
|
759
|
+
|
|
760
|
+
closest_p1, barycentric1 = triangle_closest_point_for_test(u1, u2, u3, p)
|
|
761
|
+
|
|
762
|
+
dis = wp.length(closest_p1 - p)
|
|
763
|
+
|
|
764
|
+
if dis < query_radius:
|
|
765
|
+
num_collisions = num_collisions + 1
|
|
766
|
+
|
|
767
|
+
query_results_num_collisions[p_index] = num_collisions
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
@unittest.skipUnless(USD_AVAILABLE, "Requires usd-core")
|
|
771
|
+
def test_set_mesh_points(test, device):
|
|
772
|
+
vs, fs = load_mesh()
|
|
773
|
+
|
|
774
|
+
vertices1 = wp.array(vs, dtype=wp.vec3, device=device)
|
|
775
|
+
velocities1_np = np.random.randn(vertices1.shape[0], 3)
|
|
776
|
+
velocities1 = wp.array(velocities1_np, dtype=wp.vec3, device=device)
|
|
777
|
+
|
|
778
|
+
faces = wp.array(fs, dtype=wp.int32, device=device)
|
|
779
|
+
mesh = wp.Mesh(vertices1, faces, velocities=velocities1)
|
|
780
|
+
fs_2D = faces.reshape((-1, 3))
|
|
781
|
+
np.random.seed(12345)
|
|
782
|
+
n = 1000
|
|
783
|
+
query_radius = 0.2
|
|
784
|
+
|
|
785
|
+
pts1 = wp.array(np.random.randn(n, 3), dtype=wp.vec3, device=device)
|
|
786
|
+
|
|
787
|
+
query_results_num_cols1 = wp.zeros(n, dtype=wp.int32, device=device)
|
|
788
|
+
query_results_min_dist1 = wp.zeros(n, dtype=float, device=device)
|
|
789
|
+
query_results_closest_point_velocity1 = wp.zeros(n, dtype=wp.vec3, device=device)
|
|
790
|
+
|
|
791
|
+
wp.launch(
|
|
792
|
+
kernel=point_query_aabb_and_closest,
|
|
793
|
+
inputs=[
|
|
794
|
+
query_radius,
|
|
795
|
+
mesh.id,
|
|
796
|
+
pts1,
|
|
797
|
+
vertices1,
|
|
798
|
+
fs_2D,
|
|
799
|
+
query_results_num_cols1,
|
|
800
|
+
query_results_min_dist1,
|
|
801
|
+
query_results_closest_point_velocity1,
|
|
802
|
+
],
|
|
803
|
+
dim=n,
|
|
804
|
+
device=device,
|
|
805
|
+
)
|
|
806
|
+
|
|
807
|
+
shift = np.random.randn(3)
|
|
808
|
+
|
|
809
|
+
vs_higher = vs + shift
|
|
810
|
+
vertices2 = wp.array(vs_higher, dtype=wp.vec3, device=device)
|
|
811
|
+
|
|
812
|
+
velocities2_np = velocities1_np + shift[None, ...]
|
|
813
|
+
velocities2 = wp.array(velocities2_np, dtype=wp.vec3, device=device)
|
|
814
|
+
|
|
815
|
+
pts2 = wp.array(pts1.numpy() + shift, dtype=wp.vec3, device=device)
|
|
816
|
+
|
|
817
|
+
mesh.points = vertices2
|
|
818
|
+
mesh.velocities = velocities2
|
|
819
|
+
|
|
820
|
+
query_results_num_cols2 = wp.zeros(n, dtype=wp.int32, device=device)
|
|
821
|
+
query_results_min_dist2 = wp.zeros(n, dtype=float, device=device)
|
|
822
|
+
query_results_closest_point_velocity2 = wp.array([shift for i in range(n)], dtype=wp.vec3, device=device)
|
|
823
|
+
|
|
824
|
+
wp.launch(
|
|
825
|
+
kernel=point_query_aabb_and_closest,
|
|
826
|
+
inputs=[
|
|
827
|
+
query_radius,
|
|
828
|
+
mesh.id,
|
|
829
|
+
pts2,
|
|
830
|
+
vertices2,
|
|
831
|
+
fs_2D,
|
|
832
|
+
query_results_num_cols2,
|
|
833
|
+
query_results_min_dist2,
|
|
834
|
+
query_results_closest_point_velocity2,
|
|
835
|
+
],
|
|
836
|
+
dim=n,
|
|
837
|
+
device=device,
|
|
838
|
+
)
|
|
839
|
+
|
|
840
|
+
test.assertTrue((query_results_num_cols1.numpy() == query_results_num_cols2.numpy()).all())
|
|
841
|
+
test.assertTrue(((query_results_min_dist1.numpy() - query_results_min_dist2.numpy()) < 1e-5).all())
|
|
842
|
+
|
|
843
|
+
|
|
657
844
|
devices = get_test_devices()
|
|
658
845
|
|
|
659
846
|
|
|
@@ -684,6 +871,7 @@ class TestMeshQueryPoint(unittest.TestCase):
|
|
|
684
871
|
add_function_test(TestMeshQueryPoint, "test_mesh_query_point", test_mesh_query_point, devices=devices)
|
|
685
872
|
add_function_test(TestMeshQueryPoint, "test_mesh_query_furthest_point", test_mesh_query_furthest_point, devices=devices)
|
|
686
873
|
add_function_test(TestMeshQueryPoint, "test_adj_mesh_query_point", test_adj_mesh_query_point, devices=devices)
|
|
874
|
+
add_function_test(TestMeshQueryPoint, "test_set_mesh_points", test_set_mesh_points, devices=devices)
|
|
687
875
|
|
|
688
876
|
if __name__ == "__main__":
|
|
689
877
|
wp.clear_kernel_cache()
|
|
@@ -25,6 +25,11 @@ def fn():
|
|
|
25
25
|
@wp.func
|
|
26
26
|
def fn(value: int):
|
|
27
27
|
wp.print(value)
|
|
28
|
+
|
|
29
|
+
@wp.kernel
|
|
30
|
+
def k():
|
|
31
|
+
print(fn())
|
|
32
|
+
print(fn(99))
|
|
28
33
|
"""
|
|
29
34
|
|
|
30
35
|
# should be same hash as FUNC_OVERLOAD_1
|
|
@@ -38,6 +43,11 @@ def fn():
|
|
|
38
43
|
@wp.func
|
|
39
44
|
def fn(value: int):
|
|
40
45
|
wp.print(value)
|
|
46
|
+
|
|
47
|
+
@wp.kernel
|
|
48
|
+
def k():
|
|
49
|
+
print(fn())
|
|
50
|
+
print(fn(99))
|
|
41
51
|
"""
|
|
42
52
|
|
|
43
53
|
# should be different hash than FUNC_OVERLOAD_1 (first overload is different)
|
|
@@ -51,6 +61,11 @@ def fn():
|
|
|
51
61
|
@wp.func
|
|
52
62
|
def fn(value: int):
|
|
53
63
|
wp.print(value)
|
|
64
|
+
|
|
65
|
+
@wp.kernel
|
|
66
|
+
def k():
|
|
67
|
+
print(fn())
|
|
68
|
+
print(fn(99))
|
|
54
69
|
"""
|
|
55
70
|
|
|
56
71
|
# should be different hash than FUNC_OVERLOAD_1 (second overload is different)
|
|
@@ -64,6 +79,11 @@ def fn():
|
|
|
64
79
|
@wp.func
|
|
65
80
|
def fn(value: int):
|
|
66
81
|
wp.print(value + 1)
|
|
82
|
+
|
|
83
|
+
@wp.kernel
|
|
84
|
+
def k():
|
|
85
|
+
print(fn())
|
|
86
|
+
print(fn(99))
|
|
67
87
|
"""
|
|
68
88
|
|
|
69
89
|
FUNC_GENERIC_1 = """# -*- coding: utf-8 -*-
|
|
@@ -78,6 +98,11 @@ def generic_fn(x: Any):
|
|
|
78
98
|
@wp.func
|
|
79
99
|
def generic_fn(x: Any, y: Any):
|
|
80
100
|
return x * y
|
|
101
|
+
|
|
102
|
+
@wp.kernel
|
|
103
|
+
def k():
|
|
104
|
+
print(generic_fn(17))
|
|
105
|
+
print(generic_fn(17, 42))
|
|
81
106
|
"""
|
|
82
107
|
|
|
83
108
|
# should be same hash as FUNC_GENERIC_1
|
|
@@ -93,6 +118,11 @@ def generic_fn(x: Any):
|
|
|
93
118
|
@wp.func
|
|
94
119
|
def generic_fn(x: Any, y: Any):
|
|
95
120
|
return x * y
|
|
121
|
+
|
|
122
|
+
@wp.kernel
|
|
123
|
+
def k():
|
|
124
|
+
print(generic_fn(17))
|
|
125
|
+
print(generic_fn(17, 42))
|
|
96
126
|
"""
|
|
97
127
|
|
|
98
128
|
# should be different hash than FUNC_GENERIC_1 (first overload is different)
|
|
@@ -108,6 +138,11 @@ def generic_fn(x: Any):
|
|
|
108
138
|
@wp.func
|
|
109
139
|
def generic_fn(x: Any, y: Any):
|
|
110
140
|
return x * y
|
|
141
|
+
|
|
142
|
+
@wp.kernel
|
|
143
|
+
def k():
|
|
144
|
+
print(generic_fn(17))
|
|
145
|
+
print(generic_fn(17, 42))
|
|
111
146
|
"""
|
|
112
147
|
|
|
113
148
|
# should be different hash than FUNC_GENERIC_1 (second overload is different)
|
|
@@ -123,6 +158,11 @@ def generic_fn(x: Any):
|
|
|
123
158
|
@wp.func
|
|
124
159
|
def generic_fn(x: Any, y: Any):
|
|
125
160
|
return x + y
|
|
161
|
+
|
|
162
|
+
@wp.kernel
|
|
163
|
+
def k():
|
|
164
|
+
print(generic_fn(17))
|
|
165
|
+
print(generic_fn(17, 42))
|
|
126
166
|
"""
|
|
127
167
|
|
|
128
168
|
|
warp/tests/test_multigpu.py
CHANGED
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
|
11
11
|
|
|
12
12
|
import warp as wp
|
|
13
13
|
from warp.tests.unittest_utils import *
|
|
14
|
-
from warp.utils import
|
|
14
|
+
from warp.utils import check_p2p
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
@wp.kernel
|
|
@@ -105,7 +105,7 @@ class TestMultiGPU(unittest.TestCase):
|
|
|
105
105
|
assert wp.get_cuda_device() == initial_cuda_device
|
|
106
106
|
|
|
107
107
|
@unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
|
|
108
|
-
@unittest.skipUnless(
|
|
108
|
+
@unittest.skipUnless(check_p2p(), "Peer-to-Peer transfers not supported")
|
|
109
109
|
def test_multigpu_pingpong(self):
|
|
110
110
|
n = 1024 * 1024
|
|
111
111
|
|
|
@@ -129,7 +129,7 @@ class TestMultiGPU(unittest.TestCase):
|
|
|
129
129
|
assert_np_equal(a1.numpy(), expected)
|
|
130
130
|
|
|
131
131
|
@unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
|
|
132
|
-
@unittest.skipUnless(
|
|
132
|
+
@unittest.skipUnless(check_p2p(), "Peer-to-Peer transfers not supported")
|
|
133
133
|
def test_multigpu_pingpong_streams(self):
|
|
134
134
|
n = 1024 * 1024
|
|
135
135
|
|
warp/tests/test_overwrite.py
CHANGED
|
@@ -7,6 +7,8 @@ import numpy as np
|
|
|
7
7
|
import warp as wp
|
|
8
8
|
from warp.tests.unittest_utils import *
|
|
9
9
|
|
|
10
|
+
wp.init() # For wp.context.runtime.core.is_cutlass_enabled()
|
|
11
|
+
|
|
10
12
|
# kernels are defined in the global scope, to ensure wp.Kernel objects are not GC'ed in the MGPU case
|
|
11
13
|
# kernel args are assigned array modes during codegen, so wp.Kernel objects generated during codegen
|
|
12
14
|
# must be preserved for overwrite tracking to function
|
|
@@ -362,6 +364,9 @@ def test_copy(test, device):
|
|
|
362
364
|
|
|
363
365
|
# wp.matmul uses wp.record_func. Ensure array modes are propagated correctly.
|
|
364
366
|
def test_matmul(test, device):
|
|
367
|
+
if device.is_cuda and not wp.context.runtime.core.is_cutlass_enabled():
|
|
368
|
+
test.skipTest("Warp was not built with CUTLASS support")
|
|
369
|
+
|
|
365
370
|
saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
|
|
366
371
|
try:
|
|
367
372
|
wp.config.verify_autograd_array_access = True
|
|
@@ -387,6 +392,9 @@ def test_matmul(test, device):
|
|
|
387
392
|
|
|
388
393
|
# wp.batched_matmul uses wp.record_func. Ensure array modes are propagated correctly.
|
|
389
394
|
def test_batched_matmul(test, device):
|
|
395
|
+
if device.is_cuda and not wp.context.runtime.core.is_cutlass_enabled():
|
|
396
|
+
test.skipTest("Warp was not built with CUTLASS support")
|
|
397
|
+
|
|
390
398
|
saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
|
|
391
399
|
try:
|
|
392
400
|
wp.config.verify_autograd_array_access = True
|