warp-lang: warp_lang-1.7.2-py3-none-manylinux_2_34_aarch64.whl → warp_lang-1.8.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for this release for more details.

Files changed (180):
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/tests/test_fem.py CHANGED
@@ -476,11 +476,8 @@ def _test_geo_cells(
476
476
  wp.atomic_add(cell_measures, s.element_index, fem.measure(domain, s) * s.qp_weight)
477
477
 
478
478
 
479
- @fem.integrand(kernel_options={"enable_backward": False})
480
- def _test_cell_lookup(
481
- s: fem.Sample,
482
- domain: fem.Domain,
483
- ):
479
+ @fem.integrand(kernel_options={"enable_backward": False, "max_unroll": 2})
480
+ def _test_cell_lookup(s: fem.Sample, domain: fem.Domain, cell_filter: wp.array(dtype=int)):
484
481
  pos = domain(s)
485
482
 
486
483
  s_guess = fem.lookup(domain, pos, s)
@@ -491,6 +488,23 @@ def _test_cell_lookup(
491
488
  wp.expect_eq(s_noguess.element_index, s.element_index)
492
489
  wp.expect_near(domain(s_noguess), pos, 0.001)
493
490
 
491
+ # Filtered lookup
492
+ max_dist = 10.0
493
+ filter_target = 1
494
+ s_filter = fem.lookup(domain, pos, max_dist, cell_filter, filter_target)
495
+ wp.expect_eq(s_filter.element_index, 0)
496
+
497
+ if s.element_index != 0:
498
+ # test closest point optimality
499
+ pos_f = domain(s_filter)
500
+ pos_f += 0.1 * (pos - pos_f)
501
+ coord_proj, _sq_dist = fem.element_closest_point(domain, s_filter.element_index, pos_f)
502
+ wp.expect_near(coord_proj, s_filter.element_coords, 0.001)
503
+
504
+ # test that extrapolated coordinates yield bak correct position
505
+ s_filter.element_coords = fem.element_coordinates(domain, s_filter.element_index, pos)
506
+ wp.expect_near(domain(s_filter), pos, 0.001)
507
+
494
508
 
495
509
  @fem.integrand(kernel_options={"enable_backward": False, "max_unroll": 1})
496
510
  def _test_geo_sides(
@@ -520,6 +534,7 @@ def _test_geo_sides(
520
534
 
521
535
  wp.expect_near(coords, inner_side_s.element_coords, 0.0001)
522
536
  wp.expect_near(coords, outer_side_s.element_coords, 0.0001)
537
+ wp.expect_near(coords, fem.element_coordinates(domain, side_index, pos_side), 0.001)
523
538
 
524
539
  area = fem.measure(domain, s)
525
540
  wp.atomic_add(side_measures, side_index, area * s.qp_weight)
@@ -544,7 +559,7 @@ def _test_side_normals(
544
559
  wp.expect_near(F_cross[k], nor[k], 0.0001)
545
560
 
546
561
 
547
- def _launch_test_geometry_kernel(geo: fem.Geometry, device, test_cell_lookup: bool = True):
562
+ def _launch_test_geometry_kernel(geo: fem.Geometry, device):
548
563
  cell_measures = wp.zeros(dtype=float, device=device, shape=geo.cell_count())
549
564
  cell_quadrature = fem.RegularQuadrature(fem.Cells(geo), order=2)
550
565
 
@@ -557,11 +572,11 @@ def _launch_test_geometry_kernel(geo: fem.Geometry, device, test_cell_lookup: bo
557
572
  quadrature=cell_quadrature,
558
573
  values={"cell_measures": cell_measures},
559
574
  )
560
- if test_cell_lookup:
561
- fem.interpolate(
562
- _test_cell_lookup,
563
- quadrature=cell_quadrature,
564
- )
575
+
576
+ cell_filter = np.zeros(geo.cell_count(), dtype=int)
577
+ cell_filter[0] = 1
578
+ cell_filter = wp.array(cell_filter, dtype=int)
579
+ fem.interpolate(_test_cell_lookup, quadrature=cell_quadrature, values={"cell_filter": cell_filter})
565
580
 
566
581
  fem.interpolate(
567
582
  _test_geo_sides,
@@ -637,14 +652,14 @@ def test_quad_mesh(test, device):
637
652
  with wp.ScopedDevice(device):
638
653
  positions, quad_vidx = _gen_quadmesh(N)
639
654
 
640
- geo = fem.Quadmesh2D(quad_vertex_indices=quad_vidx, positions=positions)
655
+ geo = fem.Quadmesh2D(quad_vertex_indices=quad_vidx, positions=positions, build_bvh=True)
641
656
 
642
657
  test.assertEqual(geo.cell_count(), N**2)
643
658
  test.assertEqual(geo.vertex_count(), (N + 1) ** 2)
644
659
  test.assertEqual(geo.side_count(), 2 * (N + 1) * N)
645
660
  test.assertEqual(geo.boundary_side_count(), 4 * N)
646
661
 
647
- side_measures, cell_measures = _launch_test_geometry_kernel(geo, device, test_cell_lookup=False)
662
+ side_measures, cell_measures = _launch_test_geometry_kernel(geo, device)
648
663
 
649
664
  assert_np_equal(side_measures.numpy(), np.full(side_measures.shape, 1.0 / (N)), tol=1.0e-4)
650
665
  assert_np_equal(cell_measures.numpy(), np.full(cell_measures.shape, 1.0 / (N**2)), tol=1.0e-4)
@@ -655,14 +670,14 @@ def test_quad_mesh(test, device):
655
670
  positions = np.hstack((positions, np.ones((positions.shape[0], 1))))
656
671
  positions = wp.array(positions, device=device, dtype=wp.vec3)
657
672
 
658
- geo = fem.Quadmesh3D(quad_vertex_indices=quad_vidx, positions=positions)
673
+ geo = fem.Quadmesh3D(quad_vertex_indices=quad_vidx, positions=positions, build_bvh=True)
659
674
 
660
675
  test.assertEqual(geo.cell_count(), N**2)
661
676
  test.assertEqual(geo.vertex_count(), (N + 1) ** 2)
662
677
  test.assertEqual(geo.side_count(), 2 * (N + 1) * N)
663
678
  test.assertEqual(geo.boundary_side_count(), 4 * N)
664
679
 
665
- side_measures, cell_measures = _launch_test_geometry_kernel(geo, device, test_cell_lookup=False)
680
+ side_measures, cell_measures = _launch_test_geometry_kernel(geo, device)
666
681
 
667
682
  assert_np_equal(side_measures.numpy(), np.full(side_measures.shape, 1.0 / (N)), tol=1.0e-4)
668
683
  assert_np_equal(cell_measures.numpy(), np.full(cell_measures.shape, 1.0 / (N**2)), tol=1.0e-4)
@@ -711,7 +726,7 @@ def test_hex_mesh(test, device):
711
726
  with wp.ScopedDevice(device):
712
727
  positions, tet_vidx = _gen_hexmesh(N)
713
728
 
714
- geo = fem.Hexmesh(hex_vertex_indices=tet_vidx, positions=positions)
729
+ geo = fem.Hexmesh(hex_vertex_indices=tet_vidx, positions=positions, build_bvh=True)
715
730
 
716
731
  test.assertEqual(geo.cell_count(), (N) ** 3)
717
732
  test.assertEqual(geo.vertex_count(), (N + 1) ** 3)
@@ -719,7 +734,7 @@ def test_hex_mesh(test, device):
719
734
  test.assertEqual(geo.boundary_side_count(), 6 * N * N)
720
735
  test.assertEqual(geo.edge_count(), 3 * N * (N + 1) ** 2)
721
736
 
722
- side_measures, cell_measures = _launch_test_geometry_kernel(geo, device, test_cell_lookup=False)
737
+ side_measures, cell_measures = _launch_test_geometry_kernel(geo, device)
723
738
 
724
739
  assert_np_equal(side_measures.numpy(), np.full(side_measures.shape, 1.0 / (N**2)), tol=1.0e-4)
725
740
  assert_np_equal(cell_measures.numpy(), np.full(cell_measures.shape, 1.0 / (N**3)), tol=1.0e-4)
@@ -844,7 +859,8 @@ def test_deformed_geometry(test, device):
844
859
  test.assertEqual(geo.side_count(), 6 * (N + 1) * N**2 + (N**3) * 4)
845
860
  test.assertEqual(geo.boundary_side_count(), 12 * N * N)
846
861
 
847
- side_measures, cell_measures = _launch_test_geometry_kernel(deformed_geo, device, test_cell_lookup=False)
862
+ deformed_geo.build_bvh()
863
+ side_measures, cell_measures = _launch_test_geometry_kernel(deformed_geo, device)
848
864
 
849
865
  test.assertAlmostEqual(
850
866
  np.sum(cell_measures.numpy()), scale**3, places=4, msg=f"cell_measures = {cell_measures.numpy()}"
@@ -1497,7 +1513,7 @@ def test_point_basis(test, device):
1497
1513
  point_test = fem.make_test(point_space, domain=domain)
1498
1514
 
1499
1515
  # Sample at particle positions
1500
- ones = fem.integrate(linear_form, fields={"u": point_test}, nodal=True)
1516
+ ones = fem.integrate(linear_form, fields={"u": point_test}, assembly="nodal")
1501
1517
  test.assertAlmostEqual(np.sum(ones.numpy()), 1.0, places=5)
1502
1518
 
1503
1519
  # Sampling outside of particle positions
@@ -1613,12 +1629,23 @@ def test_particle_quadratures(test, device):
1613
1629
  assert_np_equal(measures.grad.numpy(), np.full(3, 4.0)) # == 1.0 / cell_area
1614
1630
 
1615
1631
 
1616
- @fem.integrand
1617
- def _value_at_node(s: fem.Sample, f: fem.Field, values: wp.array(dtype=float)):
1632
+ @fem.integrand(kernel_options={"enable_backward": False})
1633
+ def _value_at_node(domain: fem.Domain, s: fem.Sample, f: fem.Field, values: wp.array(dtype=float)):
1634
+ # lookup at node is ambiguous, check that partition_lookup retains sample on current partition
1635
+ s_partition = fem.partition_lookup(domain, domain(s))
1636
+ wp.expect_eq(s.element_index, s_partition.element_index)
1637
+ wp.expect_neq(fem.operator.element_partition_index(domain, s.element_index), fem.NULL_ELEMENT_INDEX)
1638
+
1618
1639
  node_index = fem.operator.node_partition_index(f, s.qp_index)
1619
1640
  return values[node_index]
1620
1641
 
1621
1642
 
1643
+ @fem.integrand(kernel_options={"enable_backward": False})
1644
+ def _test_node_index(s: fem.Sample, u: fem.Field):
1645
+ wp.expect_eq(fem.node_index(u, s), s.qp_index)
1646
+ return 0.0
1647
+
1648
+
1622
1649
  def test_nodal_quadrature(test, device):
1623
1650
  geo = fem.Grid2D(res=wp.vec2i(2))
1624
1651
 
@@ -1635,15 +1662,18 @@ def test_nodal_quadrature(test, device):
1635
1662
 
1636
1663
  # test accessing data associated to a given node
1637
1664
 
1638
- piecewise_constant_space = fem.make_polynomial_space(geo, degree=0)
1639
- geo_partition = fem.LinearGeometryPartition(geo, 3, 4)
1640
- space_partition = fem.make_space_partition(piecewise_constant_space, geo_partition)
1665
+ piecewise_constant_space = fem.make_polynomial_space(geo, degree=1)
1666
+ geo_partition = fem.LinearGeometryPartition(geo, 2, 4)
1667
+ assert geo_partition.cell_count() == 1
1668
+
1669
+ space_partition = fem.make_space_partition(piecewise_constant_space, geo_partition, with_halo=False)
1670
+
1641
1671
  field = fem.make_discrete_field(piecewise_constant_space, space_partition=space_partition)
1642
1672
 
1643
1673
  partition_domain = fem.Cells(geo_partition)
1644
1674
  partition_nodal_quadrature = fem.NodalQuadrature(partition_domain, piecewise_constant_space)
1645
1675
 
1646
- partition_node_values = wp.array([5.0], dtype=float)
1676
+ partition_node_values = wp.full(value=5.0, shape=space_partition.node_count(), dtype=float)
1647
1677
  val = fem.integrate(
1648
1678
  _value_at_node,
1649
1679
  quadrature=partition_nodal_quadrature,
@@ -1652,6 +1682,9 @@ def test_nodal_quadrature(test, device):
1652
1682
  )
1653
1683
  test.assertAlmostEqual(val, 5.0 / geo.cell_count(), places=5)
1654
1684
 
1685
+ u_test = fem.make_test(space)
1686
+ fem.integrate(_test_node_index, assembly="nodal", fields={"u": u_test})
1687
+
1655
1688
 
1656
1689
  @wp.func
1657
1690
  def aniso_bicubic_fn(x: wp.vec2, scale: wp.vec2):
warp/tests/test_func.py CHANGED
@@ -259,7 +259,25 @@ def test_return_annotation_none() -> None:
259
259
  user_func_return_none()
260
260
 
261
261
 
262
- devices = get_test_devices()
262
+ @wp.func
263
+ def divide_by_zero(x: float):
264
+ return x / 0.0
265
+
266
+
267
+ @wp.func
268
+ def normalize_vector(vec_a: wp.vec3):
269
+ return wp.normalize(vec_a)
270
+
271
+
272
+ # This pair is to test the situation where one overload throws an error, but a second one works.
273
+ @wp.func
274
+ def divide_by_zero_overload(x: wp.float32):
275
+ return x / 0
276
+
277
+
278
+ @wp.func
279
+ def divide_by_zero_overload(x: wp.float64):
280
+ return wp.div(x, 0.0)
263
281
 
264
282
 
265
283
  class TestFunc(unittest.TestCase):
@@ -425,6 +443,29 @@ class TestFunc(unittest.TestCase):
425
443
  ):
426
444
  a * b
427
445
 
446
+ def test_cpython_call_user_function_with_error(self):
447
+ # Actually the following also includes a ZeroDivisionError in the message due to exception chaining,
448
+ # but I don't know how to test for that.
449
+ with self.assertRaisesRegex(
450
+ RuntimeError,
451
+ "Error calling function 'divide_by_zero'. No version succeeded. "
452
+ "See above for the error from the last version that was tried.",
453
+ ):
454
+ divide_by_zero(1.0)
455
+
456
+ def test_cpython_call_user_function_with_overloads(self):
457
+ self.assertEqual(divide_by_zero_overload(1.0), math.inf)
458
+
459
+ def test_cpython_call_user_function_with_wrong_argument_types(self):
460
+ with self.assertRaisesRegex(
461
+ RuntimeError,
462
+ "Error calling function 'normalize_vector'. No version succeeded. "
463
+ "See above for the error from the last version that was tried.",
464
+ ):
465
+ normalize_vector(1.0)
466
+
467
+
468
+ devices = get_test_devices()
428
469
 
429
470
  add_kernel_test(TestFunc, kernel=test_overload_func, name="test_overload_func", dim=1, devices=devices)
430
471
  add_function_test(TestFunc, func=test_return_func, name="test_return_func", devices=devices)
warp/tests/test_grad.py CHANGED
@@ -518,7 +518,7 @@ def test_mesh_grad(test, device):
518
518
  pos_np[i, j] += eps
519
519
  fd_grad[i, j] = (f1 - f2) / (2 * eps)
520
520
 
521
- assert np.allclose(ad_grad, fd_grad, atol=1e-3)
521
+ np.testing.assert_allclose(ad_grad, fd_grad, atol=1e-3)
522
522
 
523
523
 
524
524
  @wp.func
warp/tests/test_lerp.py CHANGED
@@ -119,9 +119,7 @@ def test_lerp(test, device):
119
119
  a = wp.array([test_data.a], dtype=data_type, device=device, requires_grad=True)
120
120
  b = wp.array([test_data.b], dtype=data_type, device=device, requires_grad=True)
121
121
  t = wp.array([test_data.t], dtype=float, device=device, requires_grad=True)
122
- out = wp.array(
123
- [0] * wp.types.type_length(data_type), dtype=data_type, device=device, requires_grad=True
124
- )
122
+ out = wp.array([0] * wp.types.type_size(data_type), dtype=data_type, device=device, requires_grad=True)
125
123
 
126
124
  with wp.Tape() as tape:
127
125
  wp.launch(kernel, dim=1, inputs=[a, b, t, out], device=device)