warp-lang 1.5.0__py3-none-macosx_10_13_universal2.whl → 1.6.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (132)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1124 -497
  8. warp/codegen.py +261 -136
  9. warp/config.py +1 -1
  10. warp/context.py +357 -119
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth.py +3 -1
  27. warp/examples/sim/example_cloth_self_contact.py +260 -0
  28. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  29. warp/examples/sim/example_jacobian_ik.py +0 -2
  30. warp/examples/sim/example_quadruped.py +5 -2
  31. warp/examples/tile/example_tile_cholesky.py +79 -0
  32. warp/examples/tile/example_tile_convolution.py +2 -2
  33. warp/examples/tile/example_tile_fft.py +2 -2
  34. warp/examples/tile/example_tile_filtering.py +3 -3
  35. warp/examples/tile/example_tile_matmul.py +4 -4
  36. warp/examples/tile/example_tile_mlp.py +12 -12
  37. warp/examples/tile/example_tile_nbody.py +180 -0
  38. warp/examples/tile/example_tile_walker.py +319 -0
  39. warp/fem/geometry/geometry.py +0 -2
  40. warp/math.py +147 -0
  41. warp/native/array.h +12 -0
  42. warp/native/builtin.h +0 -1
  43. warp/native/bvh.cpp +149 -70
  44. warp/native/bvh.cu +287 -68
  45. warp/native/bvh.h +195 -85
  46. warp/native/clang/clang.cpp +5 -1
  47. warp/native/coloring.cpp +5 -1
  48. warp/native/cuda_util.cpp +91 -53
  49. warp/native/cuda_util.h +5 -0
  50. warp/native/exports.h +40 -40
  51. warp/native/intersect.h +17 -0
  52. warp/native/mat.h +41 -0
  53. warp/native/mathdx.cpp +19 -0
  54. warp/native/mesh.cpp +25 -8
  55. warp/native/mesh.cu +153 -101
  56. warp/native/mesh.h +482 -403
  57. warp/native/quat.h +40 -0
  58. warp/native/solid_angle.h +7 -0
  59. warp/native/sort.cpp +85 -0
  60. warp/native/sort.cu +34 -0
  61. warp/native/sort.h +3 -1
  62. warp/native/spatial.h +11 -0
  63. warp/native/tile.h +1187 -669
  64. warp/native/tile_reduce.h +8 -6
  65. warp/native/vec.h +41 -0
  66. warp/native/warp.cpp +8 -1
  67. warp/native/warp.cu +263 -40
  68. warp/native/warp.h +19 -5
  69. warp/optim/linear.py +22 -4
  70. warp/render/render_opengl.py +130 -64
  71. warp/sim/__init__.py +6 -1
  72. warp/sim/collide.py +270 -26
  73. warp/sim/import_urdf.py +8 -8
  74. warp/sim/integrator_euler.py +25 -7
  75. warp/sim/integrator_featherstone.py +154 -35
  76. warp/sim/integrator_vbd.py +842 -40
  77. warp/sim/model.py +134 -72
  78. warp/sparse.py +1 -1
  79. warp/stubs.py +265 -132
  80. warp/tape.py +28 -30
  81. warp/tests/aux_test_module_unload.py +15 -0
  82. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  83. warp/tests/test_array.py +74 -0
  84. warp/tests/test_assert.py +242 -0
  85. warp/tests/test_codegen.py +14 -61
  86. warp/tests/test_collision.py +2 -2
  87. warp/tests/test_coloring.py +12 -2
  88. warp/tests/test_examples.py +12 -1
  89. warp/tests/test_func.py +21 -4
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_lerp.py +13 -87
  94. warp/tests/test_mat.py +138 -167
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +17 -16
  97. warp/tests/test_matmul_lite.py +10 -15
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +47 -2
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_smoothstep.py +17 -83
  109. warp/tests/test_static.py +19 -3
  110. warp/tests/test_tape.py +25 -0
  111. warp/tests/test_tile.py +178 -191
  112. warp/tests/test_tile_load.py +356 -0
  113. warp/tests/test_tile_mathdx.py +61 -8
  114. warp/tests/test_tile_mlp.py +17 -17
  115. warp/tests/test_tile_reduce.py +24 -18
  116. warp/tests/test_tile_shared_memory.py +66 -17
  117. warp/tests/test_tile_view.py +165 -0
  118. warp/tests/test_torch.py +35 -0
  119. warp/tests/test_utils.py +36 -24
  120. warp/tests/test_vec.py +110 -0
  121. warp/tests/unittest_suites.py +29 -4
  122. warp/tests/unittest_utils.py +30 -13
  123. warp/thirdparty/unittest_parallel.py +2 -2
  124. warp/types.py +411 -101
  125. warp/utils.py +10 -7
  126. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
  127. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
  128. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  129. warp/examples/benchmarks/benchmark_tile.py +0 -179
  130. warp/native/tile_gemm.h +0 -341
  131. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  132. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/stubs.py CHANGED
@@ -48,6 +48,8 @@ from warp.types import matrix as mat
48
48
 
49
49
  from warp.types import dtype_from_numpy, dtype_to_numpy
50
50
 
51
+ from warp.types import from_ipc_handle
52
+
51
53
  from warp.context import init, func, func_grad, func_replay, func_native, kernel, struct, overload
52
54
  from warp.context import is_cpu_available, is_cuda_available, is_device_available
53
55
  from warp.context import get_devices, get_preferred_device
@@ -70,6 +72,7 @@ from warp.context import (
70
72
  synchronize,
71
73
  force_load,
72
74
  load_module,
75
+ event_from_ipc_handle,
73
76
  )
74
77
  from warp.context import set_module_options, get_module_options, get_module
75
78
  from warp.context import capture_begin, capture_end, capture_launch
@@ -120,6 +123,8 @@ from warp.constants import *
120
123
  from . import builtins
121
124
  from warp.builtins import static
122
125
 
126
+ from warp.math import *
127
+
123
128
  import warp.config as config
124
129
 
125
130
  __version__ = config.version
@@ -895,36 +900,34 @@ def spatial_mass(
895
900
 
896
901
 
897
902
  @over
898
- def tile_zeros(m: int32, n: int32, dtype: Any, storage: str) -> Tile:
899
- """Allocates a tile of zero-initialized items.
903
+ def tile_zeros(shape: Tuple[int, ...], dtype: Any, storage: str) -> Tile:
904
+ """Allocate a tile of zero-initialized items.
900
905
 
901
- :param m: Size of the first dimension of the output tile
902
- :param n: Size of the second dimension of the output tile
903
- :param dtype: Datatype of output tile's elements
906
+ :param shape: Shape of the output tile
907
+ :param dtype: Data type of output tile's elements (default float)
904
908
  :param storage: The storage location for the tile: ``"register"`` for registers
905
909
  (default) or ``"shared"`` for shared memory.
906
- :returns: A zero-initialized tile with ``shape=(m,n)`` and the specified datatype
910
+ :returns: A zero-initialized tile with shape and data type as specified
907
911
  """
908
912
  ...
909
913
 
910
914
 
911
915
  @over
912
- def tile_ones(m: int32, n: int32, dtype: Any, storage: str) -> Tile:
913
- """Allocates a tile of one-initialized items.
916
+ def tile_ones(shape: Tuple[int, ...], dtype: Any, storage: str) -> Tile:
917
+ """Allocate a tile of one-initialized items.
914
918
 
915
- :param m: Size of the first dimension of the output tile
916
- :param n: Size of the second dimension of the output tile
917
- :param dtype: Datatype of output tile's elements
919
+ :param shape: Shape of the output tile
920
+ :param dtype: Data type of output tile's elements
918
921
  :param storage: The storage location for the tile: ``"register"`` for registers
919
922
  (default) or ``"shared"`` for shared memory.
920
- :returns: A one-initialized tile with ``shape=(m,n)`` and the specified dtype
923
+ :returns: A one-initialized tile with shape and data type as specified
921
924
  """
922
925
  ...
923
926
 
924
927
 
925
928
  @over
926
929
  def tile_arange(*args: Scalar, dtype: Any, storage: str) -> Tile:
927
- """Generates a tile of linearly spaced elements.
930
+ """Generate a tile of linearly spaced elements.
928
931
 
929
932
  :param args: Variable-length positional arguments, interpreted as:
930
933
 
@@ -932,124 +935,88 @@ def tile_arange(*args: Scalar, dtype: Any, storage: str) -> Tile:
932
935
  - ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
933
936
  - ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
934
937
 
935
- :param dtype: Datatype of output tile's elements (optional, default: int)
938
+ :param dtype: Data type of output tile's elements (optional, default: ``float``)
936
939
  :param storage: The storage location for the tile: ``"register"`` for registers
937
940
  (default) or ``"shared"`` for shared memory.
938
- :returns: A tile with ``shape=(1,n)`` with linearly spaced elements of specified dtype
941
+ :returns: A tile with ``shape=(n)`` with linearly spaced elements of specified data type
939
942
  """
940
943
  ...
941
944
 
942
945
 
943
946
  @over
944
- def tile_load(a: Array[Any], i: int32, n: int32, storage: str) -> Tile:
945
- """Loads a 1D tile from a global memory array.
947
+ def tile_load(a: Array[Any], shape: Tuple[int, ...], offset: Tuple[int, ...], storage: str):
948
+ """Loads a tile from a global memory array.
946
949
 
947
950
  This method will cooperatively load a tile from global memory using all threads in the block.
948
951
 
949
952
  :param a: The source array in global memory
950
- :param i: Offset in the source array measured in multiples of ``n``, i.e.: ``offset=i*n``
951
- :param n: The number of elements in the tile
953
+ :param shape: Shape of the tile to load, must have the same number of dimensions as ``a``
954
+ :param offset: Offset in the source array to begin reading from (optional)
952
955
  :param storage: The storage location for the tile: ``"register"`` for registers
953
956
  (default) or ``"shared"`` for shared memory.
954
- :returns: A tile with ``shape=(1,n)`` and dtype the same as the source array
955
- """
956
- ...
957
-
958
-
959
- @over
960
- def tile_load(a: Array[Any], i: int32, j: int32, m: int32, n: int32, storage: str) -> Tile:
961
- """Loads a 2D tile from a global memory array.
962
-
963
- This method will cooperatively load a tile from global memory using all threads in the block.
964
-
965
- :param a: The source array in global memory
966
- :param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
967
- :param j: Offset in the source array measured in multiples of ``n``, i.e.; ``col=j*n``
968
- :param m: The size of the tile's first dimension
969
- :param n: The size of the tile's second dimension
970
- :param storage: The storage location for the tile: ``"register"`` for registers
971
- (default) or ``"shared"`` for shared memory.
972
- :returns: A tile with ``shape=(m,n)`` and dtype the same as the source array
973
- """
974
- ...
975
-
976
-
977
- @over
978
- def tile_store(a: Array[Any], i: int32, t: Any):
979
- """Stores a 1D tile to a global memory array.
980
-
981
- This method will cooperatively store a tile to global memory using all threads in the block.
982
-
983
- :param a: The destination array in global memory
984
- :param i: Offset in the destination array measured in multiples of ``n``, i.e.: ``offset=i*n``
985
- :param t: The source tile to store data from, must have the same dtype as the destination array
957
+ :returns: A tile with shape as specified and data type the same as the source array
986
958
  """
987
959
  ...
988
960
 
989
961
 
990
962
  @over
991
- def tile_store(a: Array[Any], i: int32, j: int32, t: Any):
992
- """Stores a tile to a global memory array.
963
+ def tile_store(a: Array[Any], t: Tile, offset: Tuple[int, ...]):
964
+ """Store a tile to a global memory array.
993
965
 
994
966
  This method will cooperatively store a tile to global memory using all threads in the block.
995
967
 
996
968
  :param a: The destination array in global memory
997
- :param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
998
- :param j: Offset in the destination array measured in multiples of ``n``, i.e.; ``col=j*n``
999
- :param t: The source tile to store data from, must have the same dtype as the destination array
969
+ :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
970
+ :param offset: Offset in the destination array (optional)
1000
971
  """
1001
972
  ...
1002
973
 
1003
974
 
1004
975
  @over
1005
- def tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Any) -> Tile:
1006
- """Atomically add a tile to the array `a`, each element will be updated atomically.
976
+ def tile_atomic_add(a: Array[Any], t: Tile, offset: Tuple[int, ...]) -> Tile:
977
+ """Atomically add a 1D tile to the array `a`, each element will be updated atomically.
1007
978
 
1008
979
  :param a: Array in global memory, should have the same ``dtype`` as the input tile
1009
- :param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
1010
- :param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
1011
980
  :param t: Source tile to add to the destination array
1012
- :returns: A tile with the same dimensions and type as the source tile, holding the original value of the destination elements
981
+ :param offset: Offset in the destination array (optional)
982
+ :returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements
1013
983
  """
1014
984
  ...
1015
985
 
1016
986
 
1017
987
  @over
1018
- def tile_view(t: Tile, i: int32, j: int32, m: int32, n: int32) -> Tile:
1019
- """Return a subrange of a given tile from coordinates (i,j) to (i+m, j+n).
988
+ def tile_view(t: Tile, offset: Tuple[int, ...], shape: Tuple[int, ...]) -> Tile:
989
+ """Return a slice of a given tile [offset, offset+shape]. If shape is not specified, it will be inferred from the unspecified offset dimensions.
1020
990
 
1021
991
  :param t: Input tile to extract a subrange from
1022
- :param i: Offset in the source tile along the first dimension
1023
- :param j: Offset in the source tile along the second dimensions
1024
- :param m: Size of the subrange to return along the first dimension
1025
- :param n: Size of the subrange to return along the second dimension
1026
- :returns: A tile with dimensions (m,n) and the same datatype as the input tile
992
+ :param offset: Offset in the source tile
993
+ :param shape: Shape of the returned slice
994
+ :returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions
1027
995
  """
1028
996
  ...
1029
997
 
1030
998
 
1031
999
  @over
1032
- def tile_assign(dst: Tile, i: int32, j: int32, src: Tile):
1033
- """Assign a tile to a subrange of a destination tile at coordinates (i,j).
1000
+ def tile_assign(dst: Tile, src: Tile, offset: Tuple[int, ...]):
1001
+ """Assign a tile to a subrange of a destination tile.
1034
1002
 
1035
- :param t: The destination tile to assign to
1036
- :param i: Offset in the source tile along the first dimension
1037
- :param j: Offset in the source tile along the second dimensions
1003
+ :param dst: The destination tile to assign to
1038
1004
  :param src: The source tile to read values from
1005
+ :param offset: Offset in the destination tile to write to
1039
1006
  """
1040
1007
  ...
1041
1008
 
1042
1009
 
1043
1010
  @over
1044
1011
  def tile(x: Any) -> Tile:
1045
- """Constructs a new Tile from per-thread kernel values.
1012
+ """Construct a new tile from per-thread kernel values.
1046
1013
 
1047
1014
  This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
1048
1015
 
1049
1016
  * If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
1050
1017
  * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
1051
1018
 
1052
- :param x: A per-thread local value, e.g.: scalar, vector, or matrix.
1019
+ :param x: A per-thread local value, e.g. scalar, vector, or matrix.
1053
1020
  :returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
1054
1021
 
1055
1022
  This example shows how to create a linear sequence from thread variables:
@@ -1069,7 +1036,7 @@ def tile(x: Any) -> Tile:
1069
1036
 
1070
1037
  .. code-block:: text
1071
1038
 
1072
- tile(m=1, n=16, storage=register) = [[0 2 4 6 8 ...]]
1039
+ [0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30] = tile(shape=(16), storage=register)
1073
1040
 
1074
1041
 
1075
1042
  """
@@ -1077,16 +1044,16 @@ def tile(x: Any) -> Tile:
1077
1044
 
1078
1045
 
1079
1046
  @over
1080
- def untile(a: Any) -> Scalar:
1081
- """Convert a Tile back to per-thread values.
1047
+ def untile(a: Tile) -> Scalar:
1048
+ """Convert a tile back to per-thread values.
1082
1049
 
1083
1050
  This function converts a block-wide tile back to per-thread values.
1084
1051
 
1085
- * If the input tile is 1-dimensional then the resulting value will be a per-thread scalar
1086
- * If the input tile is 2-dimensional then the resulting value will be a per-thread vector of length M
1052
+ * If the input tile is 1D, then the resulting value will be a per-thread scalar
1053
+ * If the input tile is 2D, then the resulting value will be a per-thread vector of length M
1087
1054
 
1088
1055
  :param a: A tile with dimensions ``shape=(M, block_dim)``
1089
- :returns: A single value per-thread with the same dtype as the tile
1056
+ :returns: A single value per-thread with the same data type as the tile
1090
1057
 
1091
1058
  This example shows how to create a linear sequence from thread variables:
1092
1059
 
@@ -1100,7 +1067,7 @@ def untile(a: Any) -> Scalar:
1100
1067
  t = wp.tile(i) * 2
1101
1068
 
1102
1069
  # convert back to per-thread values
1103
- s = wp.untile()
1070
+ s = wp.untile(t)
1104
1071
 
1105
1072
  print(s)
1106
1073
 
@@ -1122,27 +1089,12 @@ def untile(a: Any) -> Scalar:
1122
1089
  ...
1123
1090
 
1124
1091
 
1125
- @over
1126
- def tile_extract(a: Tile, i: int32, j: int32) -> Scalar:
1127
- """Extracts a single element from the tile and returns it as a scalar type.
1128
-
1129
- This function will extract an element from the tile and broadcast its value to all threads in the block.
1130
-
1131
- Note that this may incur additional synchronization if the source tile is a register tile.
1132
-
1133
- :param a: Tile to extract the element from
1134
- :param i: Coordinate of element on first dimension
1135
- :param j: Coordinate of element on the second dimension
1136
- :returns: The value of the element at the specified tile location, with the same type as the input tile's per-element dtype
1137
- """
1138
- ...
1139
-
1140
-
1141
1092
  @over
1142
1093
  def tile_transpose(a: Tile) -> Tile:
1143
1094
  """Transpose a tile.
1144
1095
 
1145
- For shared memory tiles this operation will alias the input tile, register tiles will first be transferred to shared memory before transposition.
1096
+ For shared memory tiles, this operation will alias the input tile.
1097
+ Register tiles will first be transferred to shared memory before transposition.
1146
1098
 
1147
1099
  :param a: Tile to transpose with ``shape=(M,N)``
1148
1100
  :returns: Tile with ``shape=(N,M)``
@@ -1151,12 +1103,15 @@ def tile_transpose(a: Tile) -> Tile:
1151
1103
 
1152
1104
 
1153
1105
  @over
1154
- def tile_broadcast(a: Tile, m: int32, n: int32) -> Tile:
1106
+ def tile_broadcast(a: Tile, shape: Tuple[int, ...]) -> Tile:
1155
1107
  """Broadcast a tile.
1156
1108
 
1157
- This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
1109
+ This function will attempt to broadcast the input tile ``a`` to the destination ``shape``.
1110
+
1111
+ Broadcasting follows NumPy broadcast rules.
1158
1112
 
1159
1113
  :param a: Tile to broadcast
1114
+ :param shape: The shape to broadcast to
1160
1115
  :returns: Tile with the broadcast ``shape``
1161
1116
  """
1162
1117
  ...
@@ -1167,7 +1122,7 @@ def tile_sum(a: Tile) -> Tile:
1167
1122
  """Cooperatively compute the sum of the tile elements using all threads in the block.
1168
1123
 
1169
1124
  :param a: The tile to compute the sum of
1170
- :returns: A single-element tile with dimensions of (1,1) holding the sum
1125
+ :returns: A single-element tile holding the sum
1171
1126
 
1172
1127
  Example:
1173
1128
 
@@ -1175,19 +1130,19 @@ def tile_sum(a: Tile) -> Tile:
1175
1130
 
1176
1131
  @wp.kernel
1177
1132
  def compute():
1178
- t = wp.tile_ones(dtype=float, m=16, n=16)
1133
+ t = wp.tile_ones(dtype=float, shape=(16, 16))
1179
1134
  s = wp.tile_sum(t)
1180
1135
 
1181
- print(t)
1136
+ print(s)
1182
1137
 
1183
1138
 
1184
- wp.launch(compute, dim=[64], inputs=[])
1139
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
1185
1140
 
1186
1141
  Prints:
1187
1142
 
1188
1143
  .. code-block:: text
1189
1144
 
1190
- tile(m=1, n=1, storage=register) = [[256]]
1145
+ [256] = tile(shape=(1), storage=register)
1191
1146
 
1192
1147
 
1193
1148
  """
@@ -1199,7 +1154,7 @@ def tile_min(a: Tile) -> Tile:
1199
1154
  """Cooperatively compute the minimum of the tile elements using all threads in the block.
1200
1155
 
1201
1156
  :param a: The tile to compute the minimum of
1202
- :returns: A single-element tile with dimensions of (1,1) holding the minimum value
1157
+ :returns: A single-element tile holding the minimum value
1203
1158
 
1204
1159
  Example:
1205
1160
 
@@ -1207,19 +1162,19 @@ def tile_min(a: Tile) -> Tile:
1207
1162
 
1208
1163
  @wp.kernel
1209
1164
  def compute():
1210
- t = wp.tile_arange(start=--10, stop=10, dtype=float)
1165
+ t = wp.tile_arange(64, 128)
1211
1166
  s = wp.tile_min(t)
1212
1167
 
1213
- print(t)
1168
+ print(s)
1214
1169
 
1215
1170
 
1216
- wp.launch(compute, dim=[64], inputs=[])
1171
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
1217
1172
 
1218
1173
  Prints:
1219
1174
 
1220
1175
  .. code-block:: text
1221
1176
 
1222
- tile(m=1, n=1, storage=register) = [[-10]]
1177
+ [64] = tile(shape=(1), storage=register)
1223
1178
 
1224
1179
 
1225
1180
  """
@@ -1231,7 +1186,7 @@ def tile_max(a: Tile) -> Tile:
1231
1186
  """Cooperatively compute the maximum of the tile elements using all threads in the block.
1232
1187
 
1233
1188
  :param a: The tile to compute the maximum from
1234
- :returns: A single-element tile with dimensions of (1,1) holding the maximum value
1189
+ :returns: A single-element tile holding the maximum value
1235
1190
 
1236
1191
  Example:
1237
1192
 
@@ -1239,19 +1194,19 @@ def tile_max(a: Tile) -> Tile:
1239
1194
 
1240
1195
  @wp.kernel
1241
1196
  def compute():
1242
- t = wp.tile_arange(start=--10, stop=10, dtype=float)
1243
- s = wp.tile_min(t)
1197
+ t = wp.tile_arange(64, 128)
1198
+ s = wp.tile_max(t)
1244
1199
 
1245
- print(t)
1200
+ print(s)
1246
1201
 
1247
1202
 
1248
- wp.launch(compute, dim=[64], inputs=[])
1203
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
1249
1204
 
1250
1205
  Prints:
1251
1206
 
1252
1207
  .. code-block:: text
1253
1208
 
1254
- tile(m=1, n=1, storage=register) = [[10]]
1209
+ [127] = tile(shape=(1), storage=register)
1255
1210
 
1256
1211
 
1257
1212
  """
@@ -1259,14 +1214,14 @@ def tile_max(a: Tile) -> Tile:
1259
1214
 
1260
1215
 
1261
1216
  @over
1262
- def tile_reduce(op: Callable, a: Any) -> Tile:
1217
+ def tile_reduce(op: Callable, a: Tile) -> Tile:
1263
1218
  """Apply a custom reduction operator across the tile.
1264
1219
 
1265
1220
  This function cooperatively performs a reduction using the provided operator across the tile.
1266
1221
 
1267
1222
  :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
1268
- :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1269
- :returns: A single-element tile with ``shape=(1,1)`` with the same datatype as the input tile.
1223
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
1224
+ :returns: A single-element tile with the same data type as the input tile.
1270
1225
 
1271
1226
  Example:
1272
1227
 
@@ -1280,27 +1235,27 @@ def tile_reduce(op: Callable, a: Any) -> Tile:
1280
1235
  print(s)
1281
1236
 
1282
1237
 
1283
- wp.launch(factorial, dim=[16], inputs=[], block_dim=16)
1238
+ wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
1284
1239
 
1285
1240
  Prints:
1286
1241
 
1287
1242
  .. code-block:: text
1288
1243
 
1289
- tile(m=1, n=1, storage=register) = [[362880]]
1244
+ [362880] = tile(shape=(1), storage=register)
1290
1245
 
1291
1246
  """
1292
1247
  ...
1293
1248
 
1294
1249
 
1295
1250
  @over
1296
- def tile_map(op: Callable, a: Any) -> Tile:
1251
+ def tile_map(op: Callable, a: Tile) -> Tile:
1297
1252
  """Apply a unary function onto the tile.
1298
1253
 
1299
1254
  This function cooperatively applies a unary function to each element of the tile using all threads in the block.
1300
1255
 
1301
1256
  :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
1302
- :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1303
- :returns: A tile with the same dimensions and datatype as the input tile.
1257
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
1258
+ :returns: A tile with the same dimensions and data type as the input tile.
1304
1259
 
1305
1260
  Example:
1306
1261
 
@@ -1314,20 +1269,20 @@ def tile_map(op: Callable, a: Any) -> Tile:
1314
1269
  print(s)
1315
1270
 
1316
1271
 
1317
- wp.launch(compute, dim=[16], inputs=[])
1272
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
1318
1273
 
1319
1274
  Prints:
1320
1275
 
1321
1276
  .. code-block:: text
1322
1277
 
1323
- tile(m=1, n=10, storage=register) = [[0 0.0998334 0.198669 0.29552 ...]]
1278
+ [0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327] = tile(shape=(10), storage=register)
1324
1279
 
1325
1280
  """
1326
1281
  ...
1327
1282
 
1328
1283
 
1329
1284
  @over
1330
- def tile_map(op: Callable, a: Any, b: Any) -> Tile:
1285
+ def tile_map(op: Callable, a: Tile, b: Tile) -> Tile:
1331
1286
  """Apply a binary function onto the tile.
1332
1287
 
1333
1288
  This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
@@ -1345,20 +1300,20 @@ def tile_map(op: Callable, a: Any, b: Any) -> Tile:
1345
1300
  @wp.kernel
1346
1301
  def compute():
1347
1302
  a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
1348
- b = wp.tile_ones(m=1, n=10, dtype=float)
1303
+ b = wp.tile_ones(shape=10, dtype=float)
1349
1304
 
1350
1305
  s = wp.tile_map(wp.add, a, b)
1351
1306
 
1352
1307
  print(s)
1353
1308
 
1354
1309
 
1355
- wp.launch(compute, dim=[16], inputs=[])
1310
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
1356
1311
 
1357
1312
  Prints:
1358
1313
 
1359
1314
  .. code-block:: text
1360
1315
 
1361
- tile(m=1, n=10, storage=register) = [[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]]
1316
+ [1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9] = tile(shape=(10), storage=register)
1362
1317
  """
1363
1318
  ...
1364
1319
 
@@ -1374,6 +1329,9 @@ def mlp(
1374
1329
  ):
1375
1330
  """Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
1376
1331
 
1332
+ .. deprecated:: 1.6
1333
+ Use :doc:`tile primitives </modules/tiles>` instead.
1334
+
1377
1335
  :param weights: A layer's network weights with dimensions ``(m, n)``.
1378
1336
  :param bias: An array with dimensions ``(n)``.
1379
1337
  :param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
@@ -2602,6 +2560,12 @@ def sub(a: Transformation[Scalar], b: Transformation[Scalar]) -> Transformation[
2602
2560
  ...
2603
2561
 
2604
2562
 
2563
+ @over
2564
+ def sub(a: Tile, b: Tile) -> Tile:
2565
+ """Subtract each element b from a"""
2566
+ ...
2567
+
2568
+
2605
2569
  @over
2606
2570
  def bit_and(a: Int, b: Int) -> Int:
2607
2571
  """ """
@@ -2908,6 +2872,12 @@ def unot(a: Array[Any]) -> bool:
2908
2872
  ...
2909
2873
 
2910
2874
 
2875
+ @over
2876
+ def tile_diag_add(a: Tile, d: Tile) -> Tile:
2877
+ """Add a square matrix and a diagonal matrix 'd' represented as a 1D tile"""
2878
+ ...
2879
+
2880
+
2911
2881
  @over
2912
2882
  def tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile:
2913
2883
  """Computes the matrix product and accumulates ``out += a*b``.
@@ -2952,6 +2922,8 @@ def tile_fft(inout: Tile) -> Tile:
2952
2922
 
2953
2923
  This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.
2954
2924
 
2925
+ Note that computing the adjoint is not yet supported.
2926
+
2955
2927
  Supported datatypes are:
2956
2928
  * vec2f, vec2d
2957
2929
 
@@ -2966,6 +2938,8 @@ def tile_ifft(inout: Tile) -> Tile:
2966
2938
 
2967
2939
  This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.
2968
2940
 
2941
+ Note that computing the adjoint is not yet supported.
2942
+
2969
2943
  Supported datatypes are:
2970
2944
  * vec2f, vec2d
2971
2945
 
@@ -2974,9 +2948,43 @@ def tile_ifft(inout: Tile) -> Tile:
2974
2948
  ...
2975
2949
 
2976
2950
 
2951
+ @over
2952
+ def tile_cholesky(A: Tile) -> Tile:
2953
+ """Compute the Cholesky factorization L of a matrix A.
2954
+ L is lower triangular and satisfies LL^T = A.
2955
+
2956
+ Note that computing the adjoint is not yet supported.
2957
+
2958
+ Supported datatypes are:
2959
+ * float32
2960
+ * float64
2961
+
2962
+ :param A: A square, symmetric positive-definite matrix.
2963
+ :returns: L, a square, lower-triangular matrix such that LL^T = A
2964
+ """
2965
+ ...
2966
+
2967
+
2968
+ @over
2969
+ def tile_cholesky_solve(L: Tile, x: Tile):
2970
+ """With L such that LL^T = A, solve for y in Ay = x.
2971
+
2972
+ Note that computing the adjoint is not yet supported.
2973
+
2974
+ Supported datatypes are:
2975
+ * float32
2976
+ * float64
2977
+
2978
+ :param L: A square, lower-triangular matrix such that LL^T = A
2979
+ :param x: A 1D tile of length M
2980
+ :returns: A 1D tile y of length M such that LL^T y = x
2981
+ """
2982
+ ...
2983
+
2984
+
2977
2985
  @over
2978
2986
  def static(expr: Any) -> Any:
2979
- """Evaluates a static Python expression and replaces it with its result.
2987
+ """Evaluate a static Python expression and replace it with its result.
2980
2988
 
2981
2989
  See the :ref:`code generation guide <static_expressions>` for more details.
2982
2990
 
@@ -2986,3 +2994,128 @@ def static(expr: Any) -> Any:
2986
2994
  (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).
2987
2995
  """
2988
2996
  ...
2997
+
2998
+
2999
+ @over
3000
+ def len(a: Vector[Any, Scalar]) -> int:
3001
+ """Return the number of elements in a vector."""
3002
+ ...
3003
+
3004
+
3005
+ @over
3006
+ def len(a: Quaternion[Scalar]) -> int:
3007
+ """Return the number of elements in a quaternion."""
3008
+ ...
3009
+
3010
+
3011
+ @over
3012
+ def len(a: Matrix[Any, Any, Scalar]) -> int:
3013
+ """Return the number of rows in a matrix."""
3014
+ ...
3015
+
3016
+
3017
+ @over
3018
+ def len(a: Transformation[Float]) -> int:
3019
+ """Return the number of elements in a transformation."""
3020
+ ...
3021
+
3022
+
3023
+ @over
3024
+ def len(a: Array[Any]) -> int:
3025
+ """Return the size of the first dimension in an array."""
3026
+ ...
3027
+
3028
+
3029
+ @over
3030
+ def len(a: Tile) -> int:
3031
+ """Return the number of rows in a tile."""
3032
+ ...
3033
+
3034
+
3035
+ @over
3036
+ def norm_l1(v: Any):
3037
+ r"""Computes the L1 norm of a vector v.
3038
+
3039
+ .. math:: \|v\|_1 = \sum_i |v_i|
3040
+
3041
+ Args:
3042
+ v (Vector[Any,Float]): The vector to compute the L1 norm of.
3043
+
3044
+ Returns:
3045
+ float: The L1 norm of the vector.
3046
+ """
3047
+ ...
3048
+
3049
+
3050
+ @over
3051
+ def norm_l2(v: Any):
3052
+ r"""Computes the L2 norm of a vector v.
3053
+
3054
+ .. math:: \|v\|_2 = \sqrt{\sum_i v_i^2}
3055
+
3056
+ Args:
3057
+ v (Vector[Any,Float]): The vector to compute the L2 norm of.
3058
+
3059
+ Returns:
3060
+ float: The L2 norm of the vector.
3061
+ """
3062
+ ...
3063
+
3064
+
3065
+ @over
3066
+ def norm_huber(v: Any, delta: float):
3067
+ r"""Computes the Huber norm of a vector v with a given delta.
3068
+
3069
+ .. math::
3070
+ H(v) = \begin{cases} \frac{1}{2} \|v\|^2 & \text{if } \|v\| \leq \delta \\ \delta(\|v\| - \frac{1}{2}\delta) & \text{otherwise} \end{cases}
3071
+
3072
+ .. image:: /img/norm_huber.svg
3073
+ :align: center
3074
+
3075
+ Args:
3076
+ v (Vector[Any,Float]): The vector to compute the Huber norm of.
3077
+ delta (float): The threshold value, defaults to 1.0.
3078
+
3079
+ Returns:
3080
+ float: The Huber norm of the vector.
3081
+ """
3082
+ ...
3083
+
3084
+
3085
+ @over
3086
+ def norm_pseudo_huber(v: Any, delta: float):
3087
+ r"""Computes the "pseudo" Huber norm of a vector v with a given delta.
3088
+
3089
+ .. math::
3090
+ H^\prime(v) = \delta \sqrt{1 + \frac{\|v\|^2}{\delta^2}}
3091
+
3092
+ .. image:: /img/norm_pseudo_huber.svg
3093
+ :align: center
3094
+
3095
+ Args:
3096
+ v (Vector[Any,Float]): The vector to compute the pseudo-Huber norm of.
3097
+ delta (float): The threshold value, defaults to 1.0.
3098
+
3099
+ Returns:
3100
+ float: The pseudo-Huber norm of the vector.
3101
+ """
3102
+ ...
3103
+
3104
+
3105
+ @over
3106
+ def smooth_normalize(v: Any, delta: float):
3107
+ r"""Normalizes a vector using the pseudo-Huber norm.
3108
+
3109
+ See :func:`norm_pseudo_huber`.
3110
+
3111
+ .. math::
3112
+ \frac{v}{H^\prime(v)}
3113
+
3114
+ Args:
3115
+ v (Vector[Any,Float]): The vector to normalize.
3116
+ delta (float): The threshold value, defaults to 1.0.
3117
+
3118
+ Returns:
3119
+ Vector[Any,Float]: The normalized vector.
3120
+ """
3121
+ ...