warp-lang 1.4.2__py3-none-manylinux2014_aarch64.whl → 1.5.0__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (158) hide show
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1783 -2
  8. warp/codegen.py +177 -45
  9. warp/config.py +2 -2
  10. warp/context.py +321 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +2 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -5
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +600 -0
  82. warp/native/cuda_util.cpp +14 -0
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1857 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +137 -65
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +88 -15
  117. warp/stubs.py +569 -4
  118. warp/tape.py +12 -7
  119. warp/tests/assets/pixel.npy +0 -0
  120. warp/tests/aux_test_instancing_gc.py +18 -0
  121. warp/tests/test_array.py +39 -0
  122. warp/tests/test_codegen.py +81 -1
  123. warp/tests/test_codegen_instancing.py +30 -0
  124. warp/tests/test_collision.py +110 -0
  125. warp/tests/test_coloring.py +241 -0
  126. warp/tests/test_context.py +34 -0
  127. warp/tests/test_examples.py +18 -4
  128. warp/tests/test_fem.py +453 -113
  129. warp/tests/test_func.py +13 -0
  130. warp/tests/test_generics.py +52 -0
  131. warp/tests/test_iter.py +68 -0
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_mesh_query_point.py +1 -1
  134. warp/tests/test_module_hashing.py +23 -0
  135. warp/tests/test_paddle.py +27 -87
  136. warp/tests/test_print.py +56 -1
  137. warp/tests/test_spatial.py +1 -1
  138. warp/tests/test_tile.py +700 -0
  139. warp/tests/test_tile_mathdx.py +144 -0
  140. warp/tests/test_tile_mlp.py +383 -0
  141. warp/tests/test_tile_reduce.py +374 -0
  142. warp/tests/test_tile_shared_memory.py +190 -0
  143. warp/tests/test_vbd.py +12 -20
  144. warp/tests/test_volume.py +43 -0
  145. warp/tests/unittest_suites.py +19 -2
  146. warp/tests/unittest_utils.py +4 -0
  147. warp/types.py +338 -72
  148. warp/utils.py +22 -1
  149. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  150. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
  151. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  152. warp/fem/field/test.py +0 -180
  153. warp/fem/field/trial.py +0 -183
  154. warp/fem/space/collocated_function_space.py +0 -102
  155. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  156. warp/fem/space/trimesh_2d_function_space.py +0 -153
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/stubs.py CHANGED
@@ -11,9 +11,6 @@ Length = TypeVar("Length", bound=int)
11
11
  Rows = TypeVar("Rows", bound=int)
12
12
  Cols = TypeVar("Cols", bound=int)
13
13
  DType = TypeVar("DType")
14
- Int = TypeVar("Int")
15
- Float = TypeVar("Float")
16
- Scalar = TypeVar("Scalar")
17
14
  Vector = Generic[Length, Scalar]
18
15
  Matrix = Generic[Rows, Cols, Scalar]
19
16
  Quaternion = Generic[Float]
@@ -39,6 +36,8 @@ from warp.types import transform, transformh, transformf, transformd
39
36
  from warp.types import spatial_vector, spatial_vectorh, spatial_vectorf, spatial_vectord
40
37
  from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial_matrixd
41
38
 
39
+ from warp.types import Int, Float, Scalar
40
+
42
41
  from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes
43
42
  from warp.types import BvhQuery, HashGridQuery, MeshQueryAABB, MeshQueryPoint, MeshQueryRay
44
43
 
@@ -67,6 +66,7 @@ from warp.context import (
67
66
  copy,
68
67
  from_numpy,
69
68
  launch,
69
+ launch_tiled,
70
70
  synchronize,
71
71
  force_load,
72
72
  load_module,
@@ -894,6 +894,475 @@ def spatial_mass(
894
894
  ...
895
895
 
896
896
 
897
+ @over
898
+ def tile_zeros(m: int32, n: int32, dtype: Any, storage: str) -> Tile:
899
+ """Allocates a tile of zero-initialized items.
900
+
901
+ :param m: Size of the first dimension of the output tile
902
+ :param n: Size of the second dimension of the output tile
903
+ :param dtype: Datatype of output tile's elements
904
+ :param storage: The storage location for the tile: ``"register"`` for registers
905
+ (default) or ``"shared"`` for shared memory.
906
+ :returns: A zero-initialized tile with ``shape=(m,n)`` and the specified datatype
907
+ """
908
+ ...
909
+
910
+
911
+ @over
912
+ def tile_ones(m: int32, n: int32, dtype: Any, storage: str) -> Tile:
913
+ """Allocates a tile of one-initialized items.
914
+
915
+ :param m: Size of the first dimension of the output tile
916
+ :param n: Size of the second dimension of the output tile
917
+ :param dtype: Datatype of output tile's elements
918
+ :param storage: The storage location for the tile: ``"register"`` for registers
919
+ (default) or ``"shared"`` for shared memory.
920
+ :returns: A one-initialized tile with ``shape=(m,n)`` and the specified dtype
921
+ """
922
+ ...
923
+
924
+
925
+ @over
926
+ def tile_arange(*args: Scalar, dtype: Any, storage: str) -> Tile:
927
+ """Generates a tile of linearly spaced elements.
928
+
929
+ :param args: Variable-length positional arguments, interpreted as:
930
+
931
+ - ``(stop,)``: Generates values from ``0`` to ``stop - 1``
932
+ - ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
933
+ - ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
934
+
935
+ :param dtype: Datatype of output tile's elements (optional, default: int)
936
+ :param storage: The storage location for the tile: ``"register"`` for registers
937
+ (default) or ``"shared"`` for shared memory.
938
+ :returns: A tile with ``shape=(1,n)`` with linearly spaced elements of specified dtype
939
+ """
940
+ ...
941
+
942
+
943
+ @over
944
+ def tile_load(a: Array[Any], i: int32, n: int32, storage: str) -> Tile:
945
+ """Loads a 1D tile from a global memory array.
946
+
947
+ This method will cooperatively load a tile from global memory using all threads in the block.
948
+
949
+ :param a: The source array in global memory
950
+ :param i: Offset in the source array measured in multiples of ``n``, i.e.: ``offset=i*n``
951
+ :param n: The number of elements in the tile
952
+ :param storage: The storage location for the tile: ``"register"`` for registers
953
+ (default) or ``"shared"`` for shared memory.
954
+ :returns: A tile with ``shape=(1,n)`` and dtype the same as the source array
955
+ """
956
+ ...
957
+
958
+
959
+ @over
960
+ def tile_load(a: Array[Any], i: int32, j: int32, m: int32, n: int32, storage: str) -> Tile:
961
+ """Loads a 2D tile from a global memory array.
962
+
963
+ This method will cooperatively load a tile from global memory using all threads in the block.
964
+
965
+ :param a: The source array in global memory
966
+ :param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
967
+ :param j: Offset in the source array measured in multiples of ``n``, i.e.: ``col=j*n``
968
+ :param m: The size of the tile's first dimension
969
+ :param n: The size of the tile's second dimension
970
+ :param storage: The storage location for the tile: ``"register"`` for registers
971
+ (default) or ``"shared"`` for shared memory.
972
+ :returns: A tile with ``shape=(m,n)`` and dtype the same as the source array
973
+ """
974
+ ...
975
+
976
+
977
+ @over
978
+ def tile_store(a: Array[Any], i: int32, t: Any):
979
+ """Stores a 1D tile to a global memory array.
980
+
981
+ This method will cooperatively store a tile to global memory using all threads in the block.
982
+
983
+ :param a: The destination array in global memory
984
+ :param i: Offset in the destination array measured in multiples of ``n``, i.e.: ``offset=i*n``
985
+ :param t: The source tile to store data from, must have the same dtype as the destination array
986
+ """
987
+ ...
988
+
989
+
990
+ @over
991
+ def tile_store(a: Array[Any], i: int32, j: int32, t: Any):
992
+ """Stores a tile to a global memory array.
993
+
994
+ This method will cooperatively store a tile to global memory using all threads in the block.
995
+
996
+ :param a: The destination array in global memory
997
+ :param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
998
+ :param j: Offset in the destination array measured in multiples of ``n``, i.e.: ``col=j*n``
999
+ :param t: The source tile to store data from, must have the same dtype as the destination array
1000
+ """
1001
+ ...
1002
+
1003
+
1004
+ @over
1005
+ def tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Any) -> Tile:
1006
+ """Atomically add a tile to the array `a`, each element will be updated atomically.
1007
+
1008
+ :param a: Array in global memory, should have the same ``dtype`` as the input tile
1009
+ :param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
1010
+ :param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
1011
+ :param t: Source tile to add to the destination array
1012
+ :returns: A tile with the same dimensions and type as the source tile, holding the original value of the destination elements
1013
+ """
1014
+ ...
1015
+
1016
+
1017
+ @over
1018
+ def tile_view(t: Tile, i: int32, j: int32, m: int32, n: int32) -> Tile:
1019
+ """Return a subrange of a given tile from coordinates (i,j) to (i+m, j+n).
1020
+
1021
+ :param t: Input tile to extract a subrange from
1022
+ :param i: Offset in the source tile along the first dimension
1023
+ :param j: Offset in the source tile along the second dimension
1024
+ :param m: Size of the subrange to return along the first dimension
1025
+ :param n: Size of the subrange to return along the second dimension
1026
+ :returns: A tile with dimensions (m,n) and the same datatype as the input tile
1027
+ """
1028
+ ...
1029
+
1030
+
1031
+ @over
1032
+ def tile_assign(dst: Tile, i: int32, j: int32, src: Tile):
1033
+ """Assign a tile to a subrange of a destination tile at coordinates (i,j).
1034
+
1035
+ :param dst: The destination tile to assign to
1036
+ :param i: Offset in the destination tile along the first dimension
1037
+ :param j: Offset in the destination tile along the second dimension
1038
+ :param src: The source tile to read values from
1039
+ """
1040
+ ...
1041
+
1042
+
1043
+ @over
1044
+ def tile(x: Any) -> Tile:
1045
+ """Constructs a new Tile from per-thread kernel values.
1046
+
1047
+ This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
1048
+
1049
+ * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
1050
+ * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
1051
+
1052
+ :param x: A per-thread local value, e.g.: scalar, vector, or matrix.
1053
+ :returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
1054
+
1055
+ This example shows how to create a linear sequence from thread variables:
1056
+
1057
+ .. code-block:: python
1058
+
1059
+ @wp.kernel
1060
+ def compute():
1061
+ i = wp.tid()
1062
+ t = wp.tile(i * 2)
1063
+ print(t)
1064
+
1065
+
1066
+ wp.launch(compute, dim=16, inputs=[], block_dim=16)
1067
+
1068
+ Prints:
1069
+
1070
+ .. code-block:: text
1071
+
1072
+ tile(m=1, n=16, storage=register) = [[0 2 4 6 8 ...]]
1073
+
1074
+
1075
+ """
1076
+ ...
1077
+
1078
+
1079
+ @over
1080
+ def untile(a: Any) -> Scalar:
1081
+ """Convert a Tile back to per-thread values.
1082
+
1083
+ This function converts a block-wide tile back to per-thread values.
1084
+
1085
+ * If the input tile is 1-dimensional then the resulting value will be a per-thread scalar
1086
+ * If the input tile is 2-dimensional then the resulting value will be a per-thread vector of length M
1087
+
1088
+ :param a: A tile with dimensions ``shape=(M, block_dim)``
1089
+ :returns: A single value per-thread with the same dtype as the tile
1090
+
1091
+ This example shows how to create a linear sequence from thread variables:
1092
+
1093
+ .. code-block:: python
1094
+
1095
+ @wp.kernel
1096
+ def compute():
1097
+ i = wp.tid()
1098
+
1099
+ # create block-wide tile
1100
+ t = wp.tile(i) * 2
1101
+
1102
+ # convert back to per-thread values
1103
+ s = wp.untile(t)
1104
+
1105
+ print(s)
1106
+
1107
+
1108
+ wp.launch(compute, dim=16, inputs=[], block_dim=16)
1109
+
1110
+ Prints:
1111
+
1112
+ .. code-block:: text
1113
+
1114
+ 0
1115
+ 2
1116
+ 4
1117
+ 6
1118
+ 8
1119
+ ...
1120
+
1121
+ """
1122
+ ...
1123
+
1124
+
1125
+ @over
1126
+ def tile_extract(a: Tile, i: int32, j: int32) -> Scalar:
1127
+ """Extracts a single element from the tile and returns it as a scalar type.
1128
+
1129
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
1130
+
1131
+ Note that this may incur additional synchronization if the source tile is a register tile.
1132
+
1133
+ :param a: Tile to extract the element from
1134
+ :param i: Coordinate of element on first dimension
1135
+ :param j: Coordinate of element on the second dimension
1136
+ :returns: The value of the element at the specified tile location, with the same type as the input tile's per-element dtype
1137
+ """
1138
+ ...
1139
+
1140
+
1141
+ @over
1142
+ def tile_transpose(a: Tile) -> Tile:
1143
+ """Transpose a tile.
1144
+
1145
+ For shared memory tiles this operation will alias the input tile, register tiles will first be transferred to shared memory before transposition.
1146
+
1147
+ :param a: Tile to transpose with ``shape=(M,N)``
1148
+ :returns: Tile with ``shape=(N,M)``
1149
+ """
1150
+ ...
1151
+
1152
+
1153
+ @over
1154
+ def tile_broadcast(a: Tile, m: int32, n: int32) -> Tile:
1155
+ """Broadcast a tile.
1156
+
1157
+ This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n); broadcasting follows NumPy broadcast rules.
1158
+
1159
+ :param a: Tile to broadcast
1160
+ :returns: Tile with broadcast ``shape=(m, n)``
1161
+ """
1162
+ ...
1163
+
1164
+
1165
+ @over
1166
+ def tile_sum(a: Tile) -> Tile:
1167
+ """Cooperatively compute the sum of the tile elements using all threads in the block.
1168
+
1169
+ :param a: The tile to compute the sum of
1170
+ :returns: A single-element tile with dimensions of (1,1) holding the sum
1171
+
1172
+ Example:
1173
+
1174
+ .. code-block:: python
1175
+
1176
+ @wp.kernel
1177
+ def compute():
1178
+ t = wp.tile_ones(dtype=float, m=16, n=16)
1179
+ s = wp.tile_sum(t)
1180
+
1181
+ print(s)
1182
+
1183
+
1184
+ wp.launch(compute, dim=[64], inputs=[])
1185
+
1186
+ Prints:
1187
+
1188
+ .. code-block:: text
1189
+
1190
+ tile(m=1, n=1, storage=register) = [[256]]
1191
+
1192
+
1193
+ """
1194
+ ...
1195
+
1196
+
1197
+ @over
1198
+ def tile_min(a: Tile) -> Tile:
1199
+ """Cooperatively compute the minimum of the tile elements using all threads in the block.
1200
+
1201
+ :param a: The tile to compute the minimum of
1202
+ :returns: A single-element tile with dimensions of (1,1) holding the minimum value
1203
+
1204
+ Example:
1205
+
1206
+ .. code-block:: python
1207
+
1208
+ @wp.kernel
1209
+ def compute():
1210
+ t = wp.tile_arange(start=-10, stop=10, dtype=float)
1211
+ s = wp.tile_min(t)
1212
+
1213
+ print(s)
1214
+
1215
+
1216
+ wp.launch(compute, dim=[64], inputs=[])
1217
+
1218
+ Prints:
1219
+
1220
+ .. code-block:: text
1221
+
1222
+ tile(m=1, n=1, storage=register) = [[-10]]
1223
+
1224
+
1225
+ """
1226
+ ...
1227
+
1228
+
1229
+ @over
1230
+ def tile_max(a: Tile) -> Tile:
1231
+ """Cooperatively compute the maximum of the tile elements using all threads in the block.
1232
+
1233
+ :param a: The tile to compute the maximum from
1234
+ :returns: A single-element tile with dimensions of (1,1) holding the maximum value
1235
+
1236
+ Example:
1237
+
1238
+ .. code-block:: python
1239
+
1240
+ @wp.kernel
1241
+ def compute():
1242
+ t = wp.tile_arange(start=-10, stop=10, dtype=float)
1242
+ s = wp.tile_max(t)
1243
+
1244
+ print(s)
1246
+
1247
+
1248
+ wp.launch(compute, dim=[64], inputs=[])
1249
+
1250
+ Prints:
1251
+
1252
+ .. code-block:: text
1253
+
1254
+ tile(m=1, n=1, storage=register) = [[9]]
1255
+
1256
+
1257
+ """
1258
+ ...
1259
+
1260
+
1261
+ @over
1262
+ def tile_reduce(op: Callable, a: Any) -> Tile:
1263
+ """Apply a custom reduction operator across the tile.
1264
+
1265
+ This function cooperatively performs a reduction using the provided operator across the tile.
1266
+
1267
+ :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
1268
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1269
+ :returns: A single-element tile with ``shape=(1,1)`` with the same datatype as the input tile.
1270
+
1271
+ Example:
1272
+
1273
+ .. code-block:: python
1274
+
1275
+ @wp.kernel
1276
+ def factorial():
1277
+ t = wp.tile_arange(1, 10, dtype=int)
1278
+ s = wp.tile_reduce(wp.mul, t)
1279
+
1280
+ print(s)
1281
+
1282
+
1283
+ wp.launch(factorial, dim=[16], inputs=[], block_dim=16)
1284
+
1285
+ Prints:
1286
+
1287
+ .. code-block:: text
1288
+
1289
+ tile(m=1, n=1, storage=register) = [[362880]]
1290
+
1291
+ """
1292
+ ...
1293
+
1294
+
1295
+ @over
1296
+ def tile_map(op: Callable, a: Any) -> Tile:
1297
+ """Apply a unary function onto the tile.
1298
+
1299
+ This function cooperatively applies a unary function to each element of the tile using all threads in the block.
1300
+
1301
+ :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
1302
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1303
+ :returns: A tile with the same dimensions and datatype as the input tile.
1304
+
1305
+ Example:
1306
+
1307
+ .. code-block:: python
1308
+
1309
+ @wp.kernel
1310
+ def compute():
1311
+ t = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
1312
+ s = wp.tile_map(wp.sin, t)
1313
+
1314
+ print(s)
1315
+
1316
+
1317
+ wp.launch(compute, dim=[16], inputs=[])
1318
+
1319
+ Prints:
1320
+
1321
+ .. code-block:: text
1322
+
1323
+ tile(m=1, n=10, storage=register) = [[0 0.0998334 0.198669 0.29552 ...]]
1324
+
1325
+ """
1326
+ ...
1327
+
1328
+
1329
+ @over
1330
+ def tile_map(op: Callable, a: Any, b: Any) -> Tile:
1331
+ """Apply a binary function onto the tile.
1332
+
1333
+ This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
1334
+ Both input tiles must have the same dimensions and datatype.
1335
+
1336
+ :param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
1337
+ :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1338
+ :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1339
+ :returns: A tile with the same dimensions and datatype as the input tiles.
1340
+
1341
+ Example:
1342
+
1343
+ .. code-block:: python
1344
+
1345
+ @wp.kernel
1346
+ def compute():
1347
+ a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
1348
+ b = wp.tile_ones(m=1, n=10, dtype=float)
1349
+
1350
+ s = wp.tile_map(wp.add, a, b)
1351
+
1352
+ print(s)
1353
+
1354
+
1355
+ wp.launch(compute, dim=[16], inputs=[])
1356
+
1357
+ Prints:
1358
+
1359
+ .. code-block:: text
1360
+
1361
+ tile(m=1, n=10, storage=register) = [[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]]
1362
+ """
1363
+ ...
1364
+
1365
+
897
1366
  @over
898
1367
  def mlp(
899
1368
  weights: Array[float32],
@@ -1162,6 +1631,12 @@ def closest_point_edge_edge(p1: vec3f, q1: vec3f, p2: vec3f, q2: vec3f, epsilon:
1162
1631
  ...
1163
1632
 
1164
1633
 
1634
+ @over
1635
+ def reversed(range: range_t) -> range_t:
1636
+ """Returns the range in reversed order."""
1637
+ ...
1638
+
1639
+
1165
1640
  @over
1166
1641
  def volume_sample(id: uint64, uvw: vec3f, sampling_mode: int32, dtype: Any) -> Any:
1167
1642
  """Sample the volume of type `dtype` given by ``id`` at the volume local-space point ``uvw``.
@@ -1376,7 +1851,7 @@ def randf(state: uint32, low: float32, high: float32) -> float:
1376
1851
 
1377
1852
  @over
1378
1853
  def randn(state: uint32) -> float:
1379
- """Sample a normal distribution."""
1854
+ """Sample a normal (Gaussian) distribution of mean 0 and variance 1."""
1380
1855
  ...
1381
1856
 
1382
1857
 
@@ -2091,6 +2566,12 @@ def add(a: Transformation[Scalar], b: Transformation[Scalar]) -> Transformation[
2091
2566
  ...
2092
2567
 
2093
2568
 
2569
+ @over
2570
+ def add(a: Tile, b: Tile) -> Tile:
2571
+ """Add each element of two tiles together"""
2572
+ ...
2573
+
2574
+
2094
2575
  @over
2095
2576
  def sub(a: Scalar, b: Scalar) -> Scalar:
2096
2577
  """ """
@@ -2241,6 +2722,18 @@ def mul(a: Transformation[Scalar], b: Scalar) -> Transformation[Scalar]:
2241
2722
  ...
2242
2723
 
2243
2724
 
2725
+ @over
2726
+ def mul(x: Tile, y: Scalar) -> Tile:
2727
+ """Multiply each element of a tile by a scalar"""
2728
+ ...
2729
+
2730
+
2731
+ @over
2732
+ def mul(x: Scalar, y: Tile) -> Tile:
2733
+ """Multiply each element of a tile by a scalar"""
2734
+ ...
2735
+
2736
+
2244
2737
  @over
2245
2738
  def mod(a: Scalar, b: Scalar) -> Scalar:
2246
2739
  """Modulo operation using truncated division."""
@@ -2349,6 +2842,12 @@ def neg(x: Matrix[Any, Any, Scalar]) -> Matrix[Any, Any, Scalar]:
2349
2842
  ...
2350
2843
 
2351
2844
 
2845
+ @over
2846
+ def neg(x: Tile) -> Tile:
2847
+ """Negate each element of a tile"""
2848
+ ...
2849
+
2850
+
2352
2851
  @over
2353
2852
  def unot(a: bool) -> bool:
2354
2853
  """ """
@@ -2409,6 +2908,72 @@ def unot(a: Array[Any]) -> bool:
2409
2908
  ...
2410
2909
 
2411
2910
 
2911
+ @over
2912
+ def tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile:
2913
+ """Computes the matrix product and accumulates ``out += a*b``.
2914
+
2915
+ Supported datatypes are:
2916
+ * fp16, fp32, fp64 (real)
2917
+ * vec2h, vec2f, vec2d (complex)
2918
+
2919
+ All input and output tiles must have the same datatype. Tile data will automatically be migrated
2920
+ to shared memory if necessary and will use TensorCore operations when available.
2921
+
2922
+ :param a: A tile with ``shape=(M, K)``
2923
+ :param b: A tile with ``shape=(K, N)``
2924
+ :param out: A tile with ``shape=(M, N)``
2925
+
2926
+ """
2927
+ ...
2928
+
2929
+
2930
+ @over
2931
+ def tile_matmul(a: Tile, b: Tile) -> Tile:
2932
+ """Computes the matrix product ``out = a*b``.
2933
+
2934
+ Supported datatypes are:
2935
+ * fp16, fp32, fp64 (real)
2936
+ * vec2h, vec2f, vec2d (complex)
2937
+
2938
+ Both input tiles must have the same datatype. Tile data will automatically be migrated
2939
+ to shared memory if necessary and will use TensorCore operations when available.
2940
+
2941
+ :param a: A tile with ``shape=(M, K)``
2942
+ :param b: A tile with ``shape=(K, N)``
2943
+ :returns: A tile with ``shape=(M, N)``
2944
+
2945
+ """
2946
+ ...
2947
+
2948
+
2949
+ @over
2950
+ def tile_fft(inout: Tile) -> Tile:
2951
+ """Compute the forward FFT along the second dimension of a 2D tile of data.
2952
+
2953
+ This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.
2954
+
2955
+ Supported datatypes are:
2956
+ * vec2f, vec2d
2957
+
2958
+ :param inout: The input/output tile
2959
+ """
2960
+ ...
2961
+
2962
+
2963
+ @over
2964
+ def tile_ifft(inout: Tile) -> Tile:
2965
+ """Compute the inverse FFT along the second dimension of a 2D tile of data.
2966
+
2967
+ This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.
2968
+
2969
+ Supported datatypes are:
2970
+ * vec2f, vec2d
2971
+
2972
+ :param inout: The input/output tile
2973
+ """
2974
+ ...
2975
+
2976
+
2412
2977
  @over
2413
2978
  def static(expr: Any) -> Any:
2414
2979
  """Evaluates a static Python expression and replaces it with its result.
warp/tape.py CHANGED
@@ -15,7 +15,7 @@ class Tape:
15
15
  """
16
16
  Record kernel launches within a Tape scope to enable automatic differentiation.
17
17
  Gradients can be computed after the operations have been recorded on the tape via
18
- ``tape.backward()``.
18
+ :meth:`Tape.backward()`.
19
19
 
20
20
  Example
21
21
  -------
@@ -131,6 +131,7 @@ class Tape:
131
131
  inputs = launch[3]
132
132
  outputs = launch[4]
133
133
  device = launch[5]
134
+ block_dim = launch[6]
134
135
 
135
136
  adj_inputs = []
136
137
  adj_outputs = []
@@ -153,13 +154,14 @@ class Tape:
153
154
  device=device,
154
155
  adjoint=True,
155
156
  max_blocks=max_blocks,
157
+ block_dim=block_dim,
156
158
  )
157
159
 
158
160
  # record a kernel launch on the tape
159
- def record_launch(self, kernel, dim, max_blocks, inputs, outputs, device, metadata=None):
161
+ def record_launch(self, kernel, dim, max_blocks, inputs, outputs, device, block_dim=0, metadata=None):
160
162
  if metadata is None:
161
163
  metadata = {}
162
- self.launches.append([kernel, dim, max_blocks, inputs, outputs, device, metadata])
164
+ self.launches.append([kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata])
163
165
 
164
166
  def record_func(self, backward, arrays):
165
167
  """
@@ -614,7 +616,9 @@ class ArrayStatsVisitor(TapeVisitor):
614
616
  self.array_grad_stats.insert(0, grad_stats)
615
617
 
616
618
 
617
- Launch = namedtuple("Launch", ["id", "kernel", "dim", "max_blocks", "inputs", "outputs", "device", "metadata"])
619
+ Launch = namedtuple(
620
+ "Launch", ["id", "kernel", "dim", "max_blocks", "inputs", "outputs", "device", "block_dim", "metadata"]
621
+ )
618
622
  RepeatedSequence = namedtuple("RepeatedSequence", ["start", "end", "repetitions"])
619
623
 
620
624
 
@@ -645,8 +649,8 @@ def visit_tape(
645
649
  def get_launch_id(launch):
646
650
  kernel = launch[0]
647
651
  suffix = ""
648
- if len(launch) > 6:
649
- metadata = launch[6]
652
+ if len(launch) > 7:
653
+ metadata = launch[7]
650
654
  # calling function helps to identify unique launches
651
655
  if "caller" in metadata:
652
656
  caller = metadata["caller"]
@@ -680,7 +684,8 @@ def visit_tape(
680
684
  inputs=launch[3],
681
685
  outputs=launch[4],
682
686
  device=launch[5],
683
- metadata=launch[6] if len(launch) > 6 else {},
687
+ block_dim=launch[6],
688
+ metadata=launch[7] if len(launch) > 7 else {},
684
689
  )
685
690
  for launch in kernel_launches
686
691
  ]
Binary file