warp-lang: warp_lang-1.3.2-py3-none-manylinux2014_aarch64.whl → warp_lang-1.4.0-py3-none-manylinux2014_aarch64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's page on the registry for more details.

Files changed (107)
  1. warp/__init__.py +6 -0
  2. warp/autograd.py +59 -6
  3. warp/bin/warp.so +0 -0
  4. warp/build_dll.py +8 -10
  5. warp/builtins.py +126 -4
  6. warp/codegen.py +435 -53
  7. warp/config.py +1 -1
  8. warp/context.py +678 -403
  9. warp/dlpack.py +2 -0
  10. warp/examples/benchmarks/benchmark_cloth.py +10 -0
  11. warp/examples/core/example_render_opengl.py +12 -10
  12. warp/examples/fem/example_adaptive_grid.py +251 -0
  13. warp/examples/fem/example_apic_fluid.py +1 -1
  14. warp/examples/fem/example_diffusion_3d.py +2 -2
  15. warp/examples/fem/example_magnetostatics.py +1 -1
  16. warp/examples/fem/example_streamlines.py +1 -0
  17. warp/examples/fem/utils.py +23 -4
  18. warp/examples/sim/example_cloth.py +50 -6
  19. warp/fem/__init__.py +2 -0
  20. warp/fem/adaptivity.py +493 -0
  21. warp/fem/field/field.py +2 -1
  22. warp/fem/field/nodal_field.py +18 -26
  23. warp/fem/field/test.py +4 -4
  24. warp/fem/field/trial.py +4 -4
  25. warp/fem/geometry/__init__.py +1 -0
  26. warp/fem/geometry/adaptive_nanogrid.py +843 -0
  27. warp/fem/geometry/nanogrid.py +55 -28
  28. warp/fem/space/__init__.py +1 -1
  29. warp/fem/space/nanogrid_function_space.py +69 -35
  30. warp/fem/utils.py +113 -107
  31. warp/jax_experimental.py +28 -15
  32. warp/native/array.h +0 -1
  33. warp/native/builtin.h +103 -6
  34. warp/native/bvh.cu +2 -0
  35. warp/native/cuda_util.cpp +14 -0
  36. warp/native/cuda_util.h +2 -0
  37. warp/native/error.cpp +4 -2
  38. warp/native/exports.h +99 -17
  39. warp/native/mat.h +97 -0
  40. warp/native/mesh.cpp +36 -0
  41. warp/native/mesh.cu +51 -0
  42. warp/native/mesh.h +1 -0
  43. warp/native/quat.h +43 -0
  44. warp/native/spatial.h +6 -0
  45. warp/native/vec.h +74 -0
  46. warp/native/warp.cpp +2 -1
  47. warp/native/warp.cu +10 -3
  48. warp/native/warp.h +8 -1
  49. warp/paddle.py +382 -0
  50. warp/sim/__init__.py +1 -0
  51. warp/sim/collide.py +519 -0
  52. warp/sim/integrator_euler.py +18 -5
  53. warp/sim/integrator_featherstone.py +5 -5
  54. warp/sim/integrator_vbd.py +1026 -0
  55. warp/sim/model.py +49 -23
  56. warp/stubs.py +459 -0
  57. warp/tape.py +2 -0
  58. warp/tests/aux_test_dependent.py +1 -0
  59. warp/tests/aux_test_name_clash1.py +32 -0
  60. warp/tests/aux_test_name_clash2.py +32 -0
  61. warp/tests/aux_test_square.py +1 -0
  62. warp/tests/test_array.py +222 -0
  63. warp/tests/test_async.py +3 -3
  64. warp/tests/test_atomic.py +6 -0
  65. warp/tests/test_closest_point_edge_edge.py +93 -1
  66. warp/tests/test_codegen.py +62 -15
  67. warp/tests/test_codegen_instancing.py +1457 -0
  68. warp/tests/test_collision.py +486 -0
  69. warp/tests/test_compile_consts.py +3 -28
  70. warp/tests/test_dlpack.py +170 -0
  71. warp/tests/test_examples.py +22 -8
  72. warp/tests/test_fast_math.py +10 -4
  73. warp/tests/test_fem.py +64 -0
  74. warp/tests/test_func.py +46 -0
  75. warp/tests/test_implicit_init.py +49 -0
  76. warp/tests/test_jax.py +58 -0
  77. warp/tests/test_mat.py +84 -0
  78. warp/tests/test_mesh_query_point.py +188 -0
  79. warp/tests/test_module_hashing.py +40 -0
  80. warp/tests/test_multigpu.py +3 -3
  81. warp/tests/test_overwrite.py +8 -0
  82. warp/tests/test_paddle.py +852 -0
  83. warp/tests/test_print.py +89 -0
  84. warp/tests/test_quat.py +111 -0
  85. warp/tests/test_reload.py +31 -1
  86. warp/tests/test_scalar_ops.py +2 -0
  87. warp/tests/test_static.py +412 -0
  88. warp/tests/test_streams.py +64 -3
  89. warp/tests/test_struct.py +4 -4
  90. warp/tests/test_torch.py +24 -0
  91. warp/tests/test_triangle_closest_point.py +137 -0
  92. warp/tests/test_types.py +1 -1
  93. warp/tests/test_vbd.py +386 -0
  94. warp/tests/test_vec.py +143 -0
  95. warp/tests/test_vec_scalar_ops.py +139 -0
  96. warp/tests/test_volume.py +30 -0
  97. warp/tests/unittest_suites.py +12 -0
  98. warp/tests/unittest_utils.py +9 -5
  99. warp/thirdparty/dlpack.py +3 -1
  100. warp/types.py +157 -34
  101. warp/utils.py +37 -14
  102. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/METADATA +10 -8
  103. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/RECORD +106 -94
  104. warp/tests/test_point_triangle_closest_point.py +0 -143
  105. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/LICENSE.md +0 -0
  106. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/WHEEL +0 -0
  107. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/top_level.txt +0 -0
warp/stubs.py CHANGED
@@ -108,11 +108,17 @@ from warp.jax import device_from_jax, device_to_jax
108
108
 
109
109
  from warp.dlpack import from_dlpack, to_dlpack
110
110
 
111
+ from warp.paddle import from_paddle, to_paddle
112
+ from warp.paddle import dtype_from_paddle, dtype_to_paddle
113
+ from warp.paddle import device_from_paddle, device_to_paddle
114
+ from warp.paddle import stream_from_paddle
115
+
111
116
  from warp.build import clear_kernel_cache
112
117
 
113
118
  from warp.constants import *
114
119
 
115
120
  from . import builtins
121
+ from warp.builtins import static
116
122
 
117
123
  import warp.config as config
118
124
 
@@ -602,6 +608,82 @@ def cw_div(a: Matrix[Any, Any, Scalar], b: Matrix[Any, Any, Scalar]) -> Matrix[A
602
608
  ...
603
609
 
604
610
 
611
+ @over
612
+ def vector(*args: Scalar, length: int32, dtype: Scalar) -> Vector[Any, Scalar]:
613
+ """Construct a vector of given length and dtype."""
614
+ ...
615
+
616
+
617
+ @over
618
+ def matrix(pos: Vector[3, Float], rot: Quaternion[Float], scale: Vector[3, Float], dtype: Float) -> Matrix[4, 4, Float]:
619
+ """Construct a 4x4 transformation matrix that applies the transformations as
620
+ Translation(pos)*Rotation(rot)*Scaling(scale) when applied to column vectors, i.e.: y = (TRS)*x
621
+ """
622
+ ...
623
+
624
+
625
+ @over
626
+ def matrix(*args: Scalar, shape: Tuple[int, int], dtype: Scalar) -> Matrix[Any, Any, Scalar]:
627
+ """Construct a matrix. If the positional ``arg_types`` are not given, then matrix will be zero-initialized."""
628
+ ...
629
+
630
+
631
+ @over
632
+ def identity(n: int32, dtype: Scalar) -> Matrix[Any, Any, Scalar]:
633
+ """Create an identity matrix with shape=(n,n) with the type given by ``dtype``."""
634
+ ...
635
+
636
+
637
+ @over
638
+ def svd3(A: Matrix[3, 3, Float], U: Matrix[3, 3, Float], sigma: Vector[3, Float], V: Matrix[3, 3, Scalar]):
639
+ """Compute the SVD of a 3x3 matrix ``A``. The singular values are returned in ``sigma``,
640
+ while the left and right basis vectors are returned in ``U`` and ``V``.
641
+ """
642
+ ...
643
+
644
+
645
+ @over
646
+ def qr3(A: Matrix[3, 3, Float], Q: Matrix[3, 3, Float], R: Matrix[3, 3, Float]):
647
+ """Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
648
+ while the upper triangular matrix is returned in ``R``.
649
+ """
650
+ ...
651
+
652
+
653
+ @over
654
+ def eig3(A: Matrix[3, 3, Float], Q: Matrix[3, 3, Float], d: Vector[3, Float]):
655
+ """Compute the eigendecomposition of a 3x3 matrix ``A``. The eigenvectors are returned as the columns of ``Q``,
656
+ while the corresponding eigenvalues are returned in ``d``.
657
+ """
658
+ ...
659
+
660
+
661
+ @over
662
+ def quaternion(dtype: Float) -> Quaternion[Float]:
663
+ """Construct a zero-initialized quaternion. Quaternions are laid out as
664
+ [ix, iy, iz, r], where ix, iy, iz are the imaginary part, and r the real part.
665
+ """
666
+ ...
667
+
668
+
669
+ @over
670
+ def quaternion(x: Float, y: Float, z: Float, w: Float) -> Quaternion[Float]:
671
+ """Create a quaternion using the supplied components (type inferred from component type)."""
672
+ ...
673
+
674
+
675
+ @over
676
+ def quaternion(ijk: Vector[3, Float], real: Float, dtype: Float) -> Quaternion[Float]:
677
+ """Create a quaternion using the supplied vector/scalar (type inferred from scalar type)."""
678
+ ...
679
+
680
+
681
+ @over
682
+ def quaternion(quat: Quaternion[Float], dtype: Float) -> Quaternion[Float]:
683
+ """Construct a quaternion of type dtype from another quaternion of a different dtype."""
684
+ ...
685
+
686
+
605
687
  @over
606
688
  def quat_identity(dtype: Float) -> quatf:
607
689
  """Construct an identity quaternion with zero imaginary part and real part of 1.0"""
@@ -662,6 +744,12 @@ def quat_to_matrix(quat: Quaternion[Float]) -> Matrix[3, 3, Float]:
662
744
  ...
663
745
 
664
746
 
747
+ @over
748
+ def transformation(pos: Vector[3, Float], rot: Quaternion[Float], dtype: Float) -> Transformation[Float]:
749
+ """Construct a rigid-body transformation with translation part ``pos`` and rotation ``rot``."""
750
+ ...
751
+
752
+
665
753
  @over
666
754
  def transform_identity(dtype: Float) -> transformf:
667
755
  """Construct an identity transform with zero translation and identity rotation."""
@@ -728,6 +816,30 @@ def transform_inverse(xform: Transformation[Float]) -> Transformation[Float]:
728
816
  ...
729
817
 
730
818
 
819
+ @over
820
+ def spatial_vector(dtype: Float) -> Vector[6, Float]:
821
+ """Zero-initialize a 6D screw vector."""
822
+ ...
823
+
824
+
825
+ @over
826
+ def spatial_vector(w: Vector[3, Float], v: Vector[3, Float], dtype: Float) -> Vector[6, Float]:
827
+ """Construct a 6D screw vector from two 3D vectors."""
828
+ ...
829
+
830
+
831
+ @over
832
+ def spatial_vector(wx: Float, wy: Float, wz: Float, vx: Float, vy: Float, vz: Float, dtype: Float) -> Vector[6, Float]:
833
+ """Construct a 6D screw vector from six values."""
834
+ ...
835
+
836
+
837
+ @over
838
+ def spatial_adjoint(r: Matrix[3, 3, Float], s: Matrix[3, 3, Float]) -> Matrix[6, 6, Float]:
839
+ """Construct a 6x6 spatial inertial matrix from two 3x3 diagonal blocks."""
840
+ ...
841
+
842
+
731
843
  @over
732
844
  def spatial_dot(a: Vector[6, Float], b: Vector[6, Float]) -> Float:
733
845
  """Compute the dot product of two 6D screw vectors."""
@@ -805,6 +917,282 @@ def mlp(
805
917
  ...
806
918
 
807
919
 
920
+ @over
921
+ def bvh_query_aabb(id: uint64, low: vec3f, high: vec3f) -> bvh_query_t:
922
+ """Construct an axis-aligned bounding box query against a BVH object.
923
+
924
+ This query can be used to iterate over all bounds inside a BVH.
925
+
926
+ :param id: The BVH identifier
927
+ :param low: The lower bound of the bounding box in BVH space
928
+ :param high: The upper bound of the bounding box in BVH space
929
+ """
930
+ ...
931
+
932
+
933
+ @over
934
+ def bvh_query_ray(id: uint64, start: vec3f, dir: vec3f) -> bvh_query_t:
935
+ """Construct a ray query against a BVH object.
936
+
937
+ This query can be used to iterate over all bounds that intersect the ray.
938
+
939
+ :param id: The BVH identifier
940
+ :param start: The start of the ray in BVH space
941
+ :param dir: The direction of the ray in BVH space
942
+ """
943
+ ...
944
+
945
+
946
+ @over
947
+ def bvh_query_next(query: bvh_query_t, index: int32) -> bool:
948
+ """Move to the next bound returned by the query.
949
+ The index of the current bound is stored in ``index``, returns ``False`` if there are no more overlapping bound.
950
+ """
951
+ ...
952
+
953
+
954
+ @over
955
+ def mesh_query_point(id: uint64, point: vec3f, max_dist: float32) -> mesh_query_point_t:
956
+ """Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
957
+
958
+ Identifies the sign of the distance using additional ray-casts to determine if the point is inside or outside.
959
+ This method is relatively robust, but does increase computational cost.
960
+ See below for additional sign determination methods.
961
+
962
+ :param id: The mesh identifier
963
+ :param point: The point in space to query
964
+ :param max_dist: Mesh faces above this distance will not be considered by the query
965
+ """
966
+ ...
967
+
968
+
969
+ @over
970
+ def mesh_query_point_no_sign(id: uint64, point: vec3f, max_dist: float32) -> mesh_query_point_t:
971
+ """Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
972
+
973
+ This method does not compute the sign of the point (inside/outside) which makes it faster than other point query methods.
974
+
975
+ :param id: The mesh identifier
976
+ :param point: The point in space to query
977
+ :param max_dist: Mesh faces above this distance will not be considered by the query
978
+ """
979
+ ...
980
+
981
+
982
+ @over
983
+ def mesh_query_furthest_point_no_sign(id: uint64, point: vec3f, min_dist: float32) -> mesh_query_point_t:
984
+ """Computes the furthest point on the mesh with identifier `id` to the given point in space.
985
+
986
+ This method does not compute the sign of the point (inside/outside).
987
+
988
+ :param id: The mesh identifier
989
+ :param point: The point in space to query
990
+ :param min_dist: Mesh faces below this distance will not be considered by the query
991
+ """
992
+ ...
993
+
994
+
995
+ @over
996
+ def mesh_query_point_sign_normal(id: uint64, point: vec3f, max_dist: float32, epsilon: float32) -> mesh_query_point_t:
997
+ """Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given ``point`` in space.
998
+
999
+ Identifies the sign of the distance (inside/outside) using the angle-weighted pseudo normal.
1000
+ This approach to sign determination is robust for well conditioned meshes that are watertight and non-self intersecting.
1001
+ It is also comparatively fast to compute.
1002
+
1003
+ :param id: The mesh identifier
1004
+ :param point: The point in space to query
1005
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1006
+ :param epsilon: Epsilon treating distance values as equal, when locating the minimum distance vertex/face/edge, as a
1007
+ fraction of the average edge length, also for treating closest point as being on edge/vertex default 1e-3
1008
+ """
1009
+ ...
1010
+
1011
+
1012
+ @over
1013
+ def mesh_query_point_sign_winding_number(
1014
+ id: uint64, point: vec3f, max_dist: float32, accuracy: float32, threshold: float32
1015
+ ) -> mesh_query_point_t:
1016
+ """Computes the closest point on the :class:`Mesh` with identifier ``id`` to the given point in space.
1017
+
1018
+ Identifies the sign using the winding number of the mesh relative to the query point. This method of sign determination is robust for poorly conditioned meshes
1019
+ and provides a smooth approximation to sign even when the mesh is not watertight. This method is the most robust and accurate of the sign determination meshes
1020
+ but also the most expensive.
1021
+
1022
+ .. note:: The :class:`Mesh` object must be constructed with ``support_winding_number=True`` for this method to return correct results.
1023
+
1024
+ :param id: The mesh identifier
1025
+ :param point: The point in space to query
1026
+ :param max_dist: Mesh faces above this distance will not be considered by the query
1027
+ :param accuracy: Accuracy for computing the winding number with fast winding number method utilizing second-order dipole approximation, default 2.0
1028
+ :param threshold: The threshold of the winding number to be considered inside, default 0.5
1029
+ """
1030
+ ...
1031
+
1032
+
1033
+ @over
1034
+ def mesh_query_ray(id: uint64, start: vec3f, dir: vec3f, max_t: float32) -> mesh_query_ray_t:
1035
+ """Computes the closest ray hit on the :class:`Mesh` with identifier ``id``.
1036
+
1037
+ :param id: The mesh identifier
1038
+ :param start: The start point of the ray
1039
+ :param dir: The ray direction (should be normalized)
1040
+ :param max_t: The maximum distance along the ray to check for intersections
1041
+ """
1042
+ ...
1043
+
1044
+
1045
+ @over
1046
+ def mesh_query_aabb(id: uint64, low: vec3f, high: vec3f) -> mesh_query_aabb_t:
1047
+ """Construct an axis-aligned bounding box query against a :class:`Mesh`.
1048
+
1049
+ This query can be used to iterate over all triangles inside a volume.
1050
+
1051
+ :param id: The mesh identifier
1052
+ :param low: The lower bound of the bounding box in mesh space
1053
+ :param high: The upper bound of the bounding box in mesh space
1054
+ """
1055
+ ...
1056
+
1057
+
1058
+ @over
1059
+ def mesh_query_aabb_next(query: mesh_query_aabb_t, index: int32) -> bool:
1060
+ """Move to the next triangle overlapping the query bounding box.
1061
+
1062
+ The index of the current face is stored in ``index``, returns ``False`` if there are no more overlapping triangles.
1063
+ """
1064
+ ...
1065
+
1066
+
1067
+ @over
1068
+ def mesh_eval_position(id: uint64, face: int32, bary_u: float32, bary_v: float32) -> vec3f:
1069
+ """Evaluates the position on the :class:`Mesh` given a face index and barycentric coordinates."""
1070
+ ...
1071
+
1072
+
1073
+ @over
1074
+ def mesh_eval_velocity(id: uint64, face: int32, bary_u: float32, bary_v: float32) -> vec3f:
1075
+ """Evaluates the velocity on the :class:`Mesh` given a face index and barycentric coordinates."""
1076
+ ...
1077
+
1078
+
1079
+ @over
1080
+ def hash_grid_query(id: uint64, point: vec3f, max_dist: float32) -> hash_grid_query_t:
1081
+ """Construct a point query against a :class:`HashGrid`.
1082
+
1083
+ This query can be used to iterate over all neighboring point within a fixed radius from the query point.
1084
+ """
1085
+ ...
1086
+
1087
+
1088
+ @over
1089
+ def hash_grid_query_next(query: hash_grid_query_t, index: int32) -> bool:
1090
+ """Move to the next point in the hash grid query.
1091
+
1092
+ The index of the current neighbor is stored in ``index``, returns ``False`` if there are no more neighbors.
1093
+ """
1094
+ ...
1095
+
1096
+
1097
+ @over
1098
+ def hash_grid_point_id(id: uint64, index: int32) -> int:
1099
+ """Return the index of a point in the :class:`HashGrid`.
1100
+
1101
+ This can be used to reorder threads such that grid traversal occurs in a spatially coherent order.
1102
+
1103
+ Returns -1 if the :class:`HashGrid` has not been reserved.
1104
+ """
1105
+ ...
1106
+
1107
+
1108
+ @over
1109
+ def intersect_tri_tri(v0: vec3f, v1: vec3f, v2: vec3f, u0: vec3f, u1: vec3f, u2: vec3f) -> int:
1110
+ """Tests for intersection between two triangles (v0, v1, v2) and (u0, u1, u2) using Moller's method.
1111
+
1112
+ Returns > 0 if triangles intersect.
1113
+ """
1114
+ ...
1115
+
1116
+
1117
+ @over
1118
+ def mesh_get(id: uint64) -> Mesh:
1119
+ """Retrieves the mesh given its index."""
1120
+ ...
1121
+
1122
+
1123
+ @over
1124
+ def mesh_eval_face_normal(id: uint64, face: int32) -> vec3f:
1125
+ """Evaluates the face normal the mesh given a face index."""
1126
+ ...
1127
+
1128
+
1129
+ @over
1130
+ def mesh_get_point(id: uint64, index: int32) -> vec3f:
1131
+ """Returns the point of the mesh given a index."""
1132
+ ...
1133
+
1134
+
1135
+ @over
1136
+ def mesh_get_velocity(id: uint64, index: int32) -> vec3f:
1137
+ """Returns the velocity of the mesh given a index."""
1138
+ ...
1139
+
1140
+
1141
+ @over
1142
+ def mesh_get_index(id: uint64, index: int32) -> int:
1143
+ """Returns the point-index of the mesh given a face-vertex index."""
1144
+ ...
1145
+
1146
+
1147
+ @over
1148
+ def closest_point_edge_edge(p1: vec3f, q1: vec3f, p2: vec3f, q2: vec3f, epsilon: float32) -> vec3f:
1149
+ """Finds the closest points between two edges.
1150
+
1151
+ Returns barycentric weights to the points on each edge, as well as the closest distance between the edges.
1152
+
1153
+ :param p1: First point of first edge
1154
+ :param q1: Second point of first edge
1155
+ :param p2: First point of second edge
1156
+ :param q2: Second point of second edge
1157
+ :param epsilon: Zero tolerance for determining if points in an edge are degenerate.
1158
+ :param out: vec3 output containing (s,t,d), where `s` in [0,1] is the barycentric weight for the first edge, `t` is the barycentric weight for the second edge, and `d` is the distance between the two edges at these two closest points.
1159
+ """
1160
+ ...
1161
+
1162
+
1163
+ @over
1164
+ def volume_sample(id: uint64, uvw: vec3f, sampling_mode: int32, dtype: Any) -> Any:
1165
+ """Sample the volume of type `dtype` given by ``id`` at the volume local-space point ``uvw``.
1166
+
1167
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`
1168
+ """
1169
+ ...
1170
+
1171
+
1172
+ @over
1173
+ def volume_sample_grad(id: uint64, uvw: vec3f, sampling_mode: int32, grad: Any, dtype: Any) -> Any:
1174
+ """Sample the volume given by ``id`` and its gradient at the volume local-space point ``uvw``.
1175
+
1176
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR.`
1177
+ """
1178
+ ...
1179
+
1180
+
1181
+ @over
1182
+ def volume_lookup(id: uint64, i: int32, j: int32, k: int32, dtype: Any) -> Any:
1183
+ """Returns the value of voxel with coordinates ``i``, ``j``, ``k`` for a volume of type type `dtype`.
1184
+
1185
+ If the voxel at this index does not exist, this function returns the background value.
1186
+ """
1187
+ ...
1188
+
1189
+
1190
+ @over
1191
+ def volume_store(id: uint64, i: int32, j: int32, k: int32, value: Any):
1192
+ """Store ``value`` at the voxel with coordinates ``i``, ``j``, ``k``."""
1193
+ ...
1194
+
1195
+
808
1196
  @over
809
1197
  def volume_sample_f(id: uint64, uvw: vec3f, sampling_mode: int32) -> float:
810
1198
  """Sample the volume given by ``id`` at the volume local-space point ``uvw``.
@@ -883,6 +1271,32 @@ def volume_store_i(id: uint64, i: int32, j: int32, k: int32, value: int32):
883
1271
  ...
884
1272
 
885
1273
 
1274
+ @over
1275
+ def volume_sample_index(id: uint64, uvw: vec3f, sampling_mode: int32, voxel_data: Array[Any], background: Any) -> Any:
1276
+ """Sample the volume given by ``id`` at the volume local-space point ``uvw``.
1277
+
1278
+ Values for allocated voxels are read from the ``voxel_data`` array, and `background` is used as the value of non-existing voxels.
1279
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR`.
1280
+ This function is available for both index grids and classical volumes.
1281
+
1282
+ """
1283
+ ...
1284
+
1285
+
1286
+ @over
1287
+ def volume_sample_grad_index(
1288
+ id: uint64, uvw: vec3f, sampling_mode: int32, voxel_data: Array[Any], background: Any, grad: Any
1289
+ ) -> Any:
1290
+ """Sample the volume given by ``id`` and its gradient at the volume local-space point ``uvw``.
1291
+
1292
+ Values for allocated voxels are read from the ``voxel_data`` array, and `background` is used as the value of non-existing voxels.
1293
+ Interpolation should be :attr:`warp.Volume.CLOSEST` or :attr:`wp.Volume.LINEAR`.
1294
+ This function is available for both index grids and classical volumes.
1295
+
1296
+ """
1297
+ ...
1298
+
1299
+
886
1300
  @over
887
1301
  def volume_lookup_index(id: uint64, i: int32, j: int32, k: int32) -> int32:
888
1302
  """Returns the index associated to the voxel with coordinates ``i``, ``j``, ``k``.
@@ -1106,6 +1520,30 @@ def printf(fmt: str, *args: Any):
1106
1520
  ...
1107
1521
 
1108
1522
 
1523
+ @over
1524
+ def print(value: Any):
1525
+ """Print variable to stdout"""
1526
+ ...
1527
+
1528
+
1529
+ @over
1530
+ def breakpoint():
1531
+ """Debugger breakpoint"""
1532
+ ...
1533
+
1534
+
1535
+ @over
1536
+ def tid() -> int:
1537
+ """Return the current thread index for a 1D kernel launch.
1538
+
1539
+ Note that this is the *global* index of the thread in the range [0, dim)
1540
+ where dim is the parameter passed to kernel launch.
1541
+
1542
+ This function may not be called from user-defined Warp functions.
1543
+ """
1544
+ ...
1545
+
1546
+
1109
1547
  @over
1110
1548
  def tid() -> Tuple[int, int]:
1111
1549
  """Return the current thread indices for a 2D kernel launch.
@@ -1807,6 +2245,12 @@ def mod(a: Scalar, b: Scalar) -> Scalar:
1807
2245
  ...
1808
2246
 
1809
2247
 
2248
+ @over
2249
+ def mod(a: Vector[Any, Scalar], b: Vector[Any, Scalar]) -> Scalar:
2250
+ """Modulo operation using truncated division."""
2251
+ ...
2252
+
2253
+
1810
2254
  @over
1811
2255
  def div(a: Scalar, b: Scalar) -> Scalar:
1812
2256
  """ """
@@ -1961,3 +2405,18 @@ def unot(a: uint64) -> bool:
1961
2405
  def unot(a: Array[Any]) -> bool:
1962
2406
  """ """
1963
2407
  ...
2408
+
2409
+
2410
+ @over
2411
+ def static(expr: Any) -> Any:
2412
+ """Evaluates a static Python expression and replaces it with its result.
2413
+
2414
+ See the `codegen.html#static-expressions <section on code generation>`_ for more details.
2415
+
2416
+ Note:
2417
+ The inner expression must only reference variables that are available from the current scope where the Warp kernel or function containing the expression is defined,
2418
+ which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
2419
+ The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
2420
+ (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).
2421
+ """
2422
+ ...
warp/tape.py CHANGED
@@ -50,6 +50,8 @@ class Tape:
50
50
  self.loss = None
51
51
 
52
52
  def __enter__(self):
53
+ wp.context.init()
54
+
53
55
  if wp.context.runtime.tape is not None:
54
56
  raise RuntimeError("Warp: Error, entering a tape while one is already active")
55
57
 
@@ -18,3 +18,4 @@ def kern(expect: float):
18
18
 
19
19
  def run(expect, device):
20
20
  wp.launch(kern, dim=1, inputs=[expect], device=device)
21
+ wp.synchronize_device(device)
@@ -0,0 +1,32 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import warp as wp
9
+
10
+
11
+ # test clashes with identical struct from another module
12
+ @wp.struct
13
+ class SameStruct:
14
+ x: float
15
+
16
+
17
+ # test clashes with identically named but different struct from another module
18
+ @wp.struct
19
+ class DifferentStruct:
20
+ v: float
21
+
22
+
23
+ # test clashes with identical function from another module
24
+ @wp.func
25
+ def same_func():
26
+ return 99
27
+
28
+
29
+ # test clashes with identically named but different function from another module
30
+ @wp.func
31
+ def different_func():
32
+ return 17
@@ -0,0 +1,32 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import warp as wp
9
+
10
+
11
+ # test clashes with identical struct from another module
12
+ @wp.struct
13
+ class SameStruct:
14
+ x: float
15
+
16
+
17
+ # test clashes with identically named but different struct from another module
18
+ @wp.struct
19
+ class DifferentStruct:
20
+ v: wp.vec2
21
+
22
+
23
+ # test clashes with identical function from another module
24
+ @wp.func
25
+ def same_func():
26
+ return 99
27
+
28
+
29
+ # test clashes with identically named but different function from another module
30
+ @wp.func
31
+ def different_func():
32
+ return 42
@@ -13,3 +13,4 @@ def kern(expect: float):
13
13
 
14
14
  def run(expect, device):
15
15
  wp.launch(kern, dim=1, inputs=[expect], device=device)
16
+ wp.synchronize_device(device)