warp-lang 1.4.1__py3-none-macosx_10_13_universal2.whl → 1.5.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (164) hide show
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/stubs.py CHANGED
@@ -11,9 +11,6 @@ Length = TypeVar("Length", bound=int)
11
11
  Rows = TypeVar("Rows", bound=int)
12
12
  Cols = TypeVar("Cols", bound=int)
13
13
  DType = TypeVar("DType")
14
- Int = TypeVar("Int")
15
- Float = TypeVar("Float")
16
- Scalar = TypeVar("Scalar")
17
14
  Vector = Generic[Length, Scalar]
18
15
  Matrix = Generic[Rows, Cols, Scalar]
19
16
  Quaternion = Generic[Float]
@@ -39,6 +36,8 @@ from warp.types import transform, transformh, transformf, transformd
39
36
  from warp.types import spatial_vector, spatial_vectorh, spatial_vectorf, spatial_vectord
40
37
  from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial_matrixd
41
38
 
39
+ from warp.types import Int, Float, Scalar
40
+
42
41
  from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes
43
42
  from warp.types import BvhQuery, HashGridQuery, MeshQueryAABB, MeshQueryPoint, MeshQueryRay
44
43
 
@@ -67,6 +66,7 @@ from warp.context import (
67
66
  copy,
68
67
  from_numpy,
69
68
  launch,
69
+ launch_tiled,
70
70
  synchronize,
71
71
  force_load,
72
72
  load_module,
@@ -785,7 +785,8 @@ def transform_point(mat: Matrix[4, 4, Float], point: Vector[3, Float]) -> Vector
785
785
  """Apply the transform to a point ``point`` treating the homogeneous coordinate as w=1.
786
786
 
787
787
  The transformation is applied treating ``point`` as a column vector, e.g.: ``y = mat*point``.
788
- Note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = point^T*mat^T``.
788
+
789
+ This is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = point^T*mat^T``.
789
790
  If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
790
791
  matrix before calling this method.
791
792
  """
@@ -802,8 +803,9 @@ def transform_vector(xform: Transformation[Float], vec: Vector[3, Float]) -> Vec
802
803
  def transform_vector(mat: Matrix[4, 4, Float], vec: Vector[3, Float]) -> Vector[3, Float]:
803
804
  """Apply the transform to a vector ``vec`` treating the homogeneous coordinate as w=0.
804
805
 
805
- The transformation is applied treating ``vec`` as a column vector, e.g.: ``y = mat*vec``
806
- note this is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = vec^T*mat^T``.
806
+ The transformation is applied treating ``vec`` as a column vector, e.g.: ``y = mat*vec``.
807
+
808
+ This is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = vec^T*mat^T``.
807
809
  If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
808
810
  matrix before calling this method.
809
811
  """
@@ -892,6 +894,475 @@ def spatial_mass(
892
894
  ...
893
895
 
894
896
 
897
+ @over
898
+ def tile_zeros(m: int32, n: int32, dtype: Any, storage: str) -> Tile:
899
+ """Allocates a tile of zero-initialized items.
900
+
901
+ :param m: Size of the first dimension of the output tile
902
+ :param n: Size of the second dimension of the output tile
903
+ :param dtype: Datatype of output tile's elements
904
+ :param storage: The storage location for the tile: ``"register"`` for registers
905
+ (default) or ``"shared"`` for shared memory.
906
+ :returns: A zero-initialized tile with ``shape=(m,n)`` and the specified datatype
907
+ """
908
+ ...
909
+
910
+
911
+ @over
912
+ def tile_ones(m: int32, n: int32, dtype: Any, storage: str) -> Tile:
913
+ """Allocates a tile of one-initialized items.
914
+
915
+ :param m: Size of the first dimension of the output tile
916
+ :param n: Size of the second dimension of the output tile
917
+ :param dtype: Datatype of output tile's elements
918
+ :param storage: The storage location for the tile: ``"register"`` for registers
919
+ (default) or ``"shared"`` for shared memory.
920
+ :returns: A one-initialized tile with ``shape=(m,n)`` and the specified dtype
921
+ """
922
+ ...
923
+
924
+
925
+ @over
926
+ def tile_arange(*args: Scalar, dtype: Any, storage: str) -> Tile:
927
+ """Generates a tile of linearly spaced elements.
928
+
929
+ :param args: Variable-length positional arguments, interpreted as:
930
+
931
+ - ``(stop,)``: Generates values from ``0`` to ``stop - 1``
932
+ - ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
933
+ - ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
934
+
935
+ :param dtype: Datatype of output tile's elements (optional, default: int)
936
+ :param storage: The storage location for the tile: ``"register"`` for registers
937
+ (default) or ``"shared"`` for shared memory.
938
+ :returns: A tile with ``shape=(1,n)`` with linearly spaced elements of specified dtype
939
+ """
940
+ ...
941
+
942
+
943
+ @over
944
+ def tile_load(a: Array[Any], i: int32, n: int32, storage: str) -> Tile:
945
+ """Loads a 1D tile from a global memory array.
946
+
947
+ This method will cooperatively load a tile from global memory using all threads in the block.
948
+
949
+ :param a: The source array in global memory
950
+ :param i: Offset in the source array measured in multiples of ``n``, i.e.: ``offset=i*n``
951
+ :param n: The number of elements in the tile
952
+ :param storage: The storage location for the tile: ``"register"`` for registers
953
+ (default) or ``"shared"`` for shared memory.
954
+ :returns: A tile with ``shape=(1,n)`` and dtype the same as the source array
955
+ """
956
+ ...
957
+
958
+
959
+ @over
960
+ def tile_load(a: Array[Any], i: int32, j: int32, m: int32, n: int32, storage: str) -> Tile:
961
+ """Loads a 2D tile from a global memory array.
962
+
963
+ This method will cooperatively load a tile from global memory using all threads in the block.
964
+
965
+ :param a: The source array in global memory
966
+ :param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
967
+ :param j: Offset in the source array measured in multiples of ``n``, i.e.: ``col=j*n``
968
+ :param m: The size of the tile's first dimension
969
+ :param n: The size of the tile's second dimension
970
+ :param storage: The storage location for the tile: ``"register"`` for registers
971
+ (default) or ``"shared"`` for shared memory.
972
+ :returns: A tile with ``shape=(m,n)`` and dtype the same as the source array
973
+ """
974
+ ...
975
+
976
+
977
+ @over
978
+ def tile_store(a: Array[Any], i: int32, t: Any):
979
+ """Stores a 1D tile to a global memory array.
980
+
981
+ This method will cooperatively store a tile to global memory using all threads in the block.
982
+
983
+ :param a: The destination array in global memory
984
+ :param i: Offset in the destination array measured in multiples of ``n``, i.e.: ``offset=i*n``
985
+ :param t: The source tile to store data from, must have the same dtype as the destination array
986
+ """
987
+ ...
988
+
989
+
990
+ @over
991
+ def tile_store(a: Array[Any], i: int32, j: int32, t: Any):
992
+ """Stores a 2D tile to a global memory array.
993
+
994
+ This method will cooperatively store a tile to global memory using all threads in the block.
995
+
996
+ :param a: The destination array in global memory
997
+ :param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
998
+ :param j: Offset in the destination array measured in multiples of ``n``, i.e.: ``col=j*n``
999
+ :param t: The source tile to store data from, must have the same dtype as the destination array
1000
+ """
1001
+ ...
1002
+
1003
+
1004
+ @over
1005
+ def tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Any) -> Tile:
1006
+ """Atomically add a tile to the array ``a``; each element will be updated atomically.
1007
+
1008
+ :param a: Array in global memory, should have the same ``dtype`` as the input tile
1009
+ :param x: Offset in the destination array measured in multiples of ``M``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
1010
+ :param y: Offset in the destination array measured in multiples of ``N``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
1011
+ :param t: Source tile to add to the destination array
1012
+ :returns: A tile with the same dimensions and type as the source tile, holding the original value of the destination elements
1013
+ """
1014
+ ...
1015
+
1016
+
1017
+ @over
1018
+ def tile_view(t: Tile, i: int32, j: int32, m: int32, n: int32) -> Tile:
1019
+ """Return a subrange of a given tile from coordinates (i,j) to (i+m, j+n).
1020
+
1021
+ :param t: Input tile to extract a subrange from
1022
+ :param i: Offset in the source tile along the first dimension
1023
+ :param j: Offset in the source tile along the second dimension
1024
+ :param m: Size of the subrange to return along the first dimension
1025
+ :param n: Size of the subrange to return along the second dimension
1026
+ :returns: A tile with dimensions (m,n) and the same datatype as the input tile
1027
+ """
1028
+ ...
1029
+
1030
+
1031
+ @over
1032
+ def tile_assign(dst: Tile, i: int32, j: int32, src: Tile):
1033
+ """Assign a tile to a subrange of a destination tile at coordinates (i,j).
1034
+
1035
+ :param dst: The destination tile to assign to
1036
+ :param i: Offset in the destination tile along the first dimension
1037
+ :param j: Offset in the destination tile along the second dimension
1038
+ :param src: The source tile to read values from
1039
+ """
1040
+ ...
1041
+
1042
+
1043
+ @over
1044
+ def tile(x: Any) -> Tile:
1045
+ """Constructs a new Tile from per-thread kernel values.
1046
+
1047
+ This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
1048
+
1049
+ * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
1050
+ * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
1051
+
1052
+ :param x: A per-thread local value, e.g.: scalar, vector, or matrix.
1053
+ :returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
1054
+
1055
+ This example shows how to create a linear sequence from thread variables:
1056
+
1057
+ .. code-block:: python
1058
+
1059
+ @wp.kernel
1060
+ def compute():
1061
+ i = wp.tid()
1062
+ t = wp.tile(i * 2)
1063
+ print(t)
1064
+
1065
+
1066
+ wp.launch(compute, dim=16, inputs=[], block_dim=16)
1067
+
1068
+ Prints:
1069
+
1070
+ .. code-block:: text
1071
+
1072
+ tile(m=1, n=16, storage=register) = [[0 2 4 6 8 ...]]
1073
+
1074
+
1075
+ """
1076
+ ...
1077
+
1078
+
1079
+ @over
1080
+ def untile(a: Any) -> Scalar:
1081
+ """Convert a Tile back to per-thread values.
1082
+
1083
+ This function converts a block-wide tile back to per-thread values.
1084
+
1085
+ * If the input tile is 1-dimensional then the resulting value will be a per-thread scalar
1086
+ * If the input tile is 2-dimensional then the resulting value will be a per-thread vector of length M
1087
+
1088
+ :param a: A tile with dimensions ``shape=(M, block_dim)``
1089
+ :returns: A single value per-thread with the same dtype as the tile
1090
+
1091
+ This example shows how to create a linear sequence from thread variables:
1092
+
1093
+ .. code-block:: python
1094
+
1095
+ @wp.kernel
1096
+ def compute():
1097
+ i = wp.tid()
1098
+
1099
+ # create block-wide tile
1100
+ t = wp.tile(i) * 2
1101
+
1102
+ # convert back to per-thread values
1103
+ s = wp.untile(t)
1104
+
1105
+ print(s)
1106
+
1107
+
1108
+ wp.launch(compute, dim=16, inputs=[], block_dim=16)
1109
+
1110
+ Prints:
1111
+
1112
+ .. code-block:: text
1113
+
1114
+ 0
1115
+ 2
1116
+ 4
1117
+ 6
1118
+ 8
1119
+ ...
1120
+
1121
+ """
1122
+ ...
1123
+
1124
+
1125
+ @over
1126
+ def tile_extract(a: Tile, i: int32, j: int32) -> Scalar:
1127
+ """Extracts a single element from the tile and returns it as a scalar type.
1128
+
1129
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
1130
+
1131
+ Note that this may incur additional synchronization if the source tile is a register tile.
1132
+
1133
+ :param a: Tile to extract the element from
1134
+ :param i: Coordinate of element on first dimension
1135
+ :param j: Coordinate of element on the second dimension
1136
+ :returns: The value of the element at the specified tile location, with the same type as the input tile's per-element dtype
1137
+ """
1138
+ ...
1139
+
1140
+
1141
+ @over
1142
+ def tile_transpose(a: Tile) -> Tile:
1143
+ """Transpose a tile.
1144
+
1145
+ For shared memory tiles this operation will alias the input tile, register tiles will first be transferred to shared memory before transposition.
1146
+
1147
+ :param a: Tile to transpose with ``shape=(M,N)``
1148
+ :returns: Tile with ``shape=(N,M)``
1149
+ """
1150
+ ...
1151
+
1152
+
1153
+ @over
1154
+ def tile_broadcast(a: Tile, m: int32, n: int32) -> Tile:
1155
+ """Broadcast a tile.
1156
+
1157
+ This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
1158
+
1159
+ :param a: Tile to broadcast
1160
+ :returns: Tile with broadcast ``shape=(m, n)``
1161
+ """
1162
+ ...
1163
+
1164
+
1165
+ @over
1166
+ def tile_sum(a: Tile) -> Tile:
1167
+ """Cooperatively compute the sum of the tile elements using all threads in the block.
1168
+
1169
+ :param a: The tile to compute the sum of
1170
+ :returns: A single-element tile with dimensions of (1,1) holding the sum
1171
+
1172
+ Example:
1173
+
1174
+ .. code-block:: python
1175
+
1176
+ @wp.kernel
1177
+ def compute():
1178
+ t = wp.tile_ones(dtype=float, m=16, n=16)
1179
+ s = wp.tile_sum(t)
1180
+
1181
+ print(s)
1182
+
1183
+
1184
+ wp.launch(compute, dim=[64], inputs=[])
1185
+
1186
+ Prints:
1187
+
1188
+ .. code-block:: text
1189
+
1190
+ tile(m=1, n=1, storage=register) = [[256]]
1191
+
1192
+
1193
+ """
1194
+ ...
1195
+
1196
+
1197
+ @over
1198
+ def tile_min(a: Tile) -> Tile:
1199
+ """Cooperatively compute the minimum of the tile elements using all threads in the block.
1200
+
1201
+ :param a: The tile to compute the minimum of
1202
+ :returns: A single-element tile with dimensions of (1,1) holding the minimum value
1203
+
1204
+ Example:
1205
+
1206
+ .. code-block:: python
1207
+
1208
+ @wp.kernel
1209
+ def compute():
1210
+ t = wp.tile_arange(start=-10, stop=10, dtype=float)
1211
+ s = wp.tile_min(t)
1212
+
1213
+ print(s)
1214
+
1215
+
1216
+ wp.launch(compute, dim=[64], inputs=[])
1217
+
1218
+ Prints:
1219
+
1220
+ .. code-block:: text
1221
+
1222
+ tile(m=1, n=1, storage=register) = [[-10]]
1223
+
1224
+
1225
+ """
1226
+ ...
1227
+
1228
+
1229
+ @over
1230
+ def tile_max(a: Tile) -> Tile:
1231
+ """Cooperatively compute the maximum of the tile elements using all threads in the block.
1232
+
1233
+ :param a: The tile to compute the maximum from
1234
+ :returns: A single-element tile with dimensions of (1,1) holding the maximum value
1235
+
1236
+ Example:
1237
+
1238
+ .. code-block:: python
1239
+
1240
+ @wp.kernel
1241
+ def compute():
1242
+ t = wp.tile_arange(start=-10, stop=10, dtype=float)
1243
+ s = wp.tile_max(t)
1244
+
1245
+ print(s)
1246
+
1247
+
1248
+ wp.launch(compute, dim=[64], inputs=[])
1249
+
1250
+ Prints:
1251
+
1252
+ .. code-block:: text
1253
+
1254
+ tile(m=1, n=1, storage=register) = [[9]]
1255
+
1256
+
1257
+ """
1258
+ ...
1259
+
1260
+
1261
+ @over
1262
+ def tile_reduce(op: Callable, a: Any) -> Tile:
1263
+ """Apply a custom reduction operator across the tile.
1264
+
1265
+ This function cooperatively performs a reduction using the provided operator across the tile.
1266
+
1267
+ :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
1268
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1269
+ :returns: A single-element tile with ``shape=(1,1)`` with the same datatype as the input tile.
1270
+
1271
+ Example:
1272
+
1273
+ .. code-block:: python
1274
+
1275
+ @wp.kernel
1276
+ def factorial():
1277
+ t = wp.tile_arange(1, 10, dtype=int)
1278
+ s = wp.tile_reduce(wp.mul, t)
1279
+
1280
+ print(s)
1281
+
1282
+
1283
+ wp.launch(factorial, dim=[16], inputs=[], block_dim=16)
1284
+
1285
+ Prints:
1286
+
1287
+ .. code-block:: text
1288
+
1289
+ tile(m=1, n=1, storage=register) = [[362880]]
1290
+
1291
+ """
1292
+ ...
1293
+
1294
+
1295
+ @over
1296
+ def tile_map(op: Callable, a: Any) -> Tile:
1297
+ """Apply a unary function onto the tile.
1298
+
1299
+ This function cooperatively applies a unary function to each element of the tile using all threads in the block.
1300
+
1301
+ :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
1302
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1303
+ :returns: A tile with the same dimensions and datatype as the input tile.
1304
+
1305
+ Example:
1306
+
1307
+ .. code-block:: python
1308
+
1309
+ @wp.kernel
1310
+ def compute():
1311
+ t = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
1312
+ s = wp.tile_map(wp.sin, t)
1313
+
1314
+ print(s)
1315
+
1316
+
1317
+ wp.launch(compute, dim=[16], inputs=[])
1318
+
1319
+ Prints:
1320
+
1321
+ .. code-block:: text
1322
+
1323
+ tile(m=1, n=10, storage=register) = [[0 0.0998334 0.198669 0.29552 ...]]
1324
+
1325
+ """
1326
+ ...
1327
+
1328
+
1329
+ @over
1330
+ def tile_map(op: Callable, a: Any, b: Any) -> Tile:
1331
+ """Apply a binary function onto the tile.
1332
+
1333
+ This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
1334
+ Both input tiles must have the same dimensions and datatype.
1335
+
1336
+ :param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
1337
+ :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1338
+ :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1339
+ :returns: A tile with the same dimensions and datatype as the input tiles.
1340
+
1341
+ Example:
1342
+
1343
+ .. code-block:: python
1344
+
1345
+ @wp.kernel
1346
+ def compute():
1347
+ a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
1348
+ b = wp.tile_ones(m=1, n=10, dtype=float)
1349
+
1350
+ s = wp.tile_map(wp.add, a, b)
1351
+
1352
+ print(s)
1353
+
1354
+
1355
+ wp.launch(compute, dim=[16], inputs=[])
1356
+
1357
+ Prints:
1358
+
1359
+ .. code-block:: text
1360
+
1361
+ tile(m=1, n=10, storage=register) = [[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]]
1362
+ """
1363
+ ...
1364
+
1365
+
895
1366
  @over
896
1367
  def mlp(
897
1368
  weights: Array[float32],
@@ -1160,6 +1631,12 @@ def closest_point_edge_edge(p1: vec3f, q1: vec3f, p2: vec3f, q2: vec3f, epsilon:
1160
1631
  ...
1161
1632
 
1162
1633
 
1634
+ @over
1635
+ def reversed(range: range_t) -> range_t:
1636
+ """Returns the range in reversed order."""
1637
+ ...
1638
+
1639
+
1163
1640
  @over
1164
1641
  def volume_sample(id: uint64, uvw: vec3f, sampling_mode: int32, dtype: Any) -> Any:
1165
1642
  """Sample the volume of type `dtype` given by ``id`` at the volume local-space point ``uvw``.
@@ -1374,7 +1851,7 @@ def randf(state: uint32, low: float32, high: float32) -> float:
1374
1851
 
1375
1852
  @over
1376
1853
  def randn(state: uint32) -> float:
1377
- """Sample a normal distribution."""
1854
+ """Sample a normal (Gaussian) distribution of mean 0 and variance 1."""
1378
1855
  ...
1379
1856
 
1380
1857
 
@@ -1638,361 +2115,361 @@ def select(arr: Array[Any], value_if_false: Any, value_if_true: Any) -> Any:
1638
2115
 
1639
2116
 
1640
2117
  @over
1641
- def atomic_add(arr: Array[Any], i: int32, value: Any) -> Any:
2118
+ def atomic_add(arr: Array[Any], i: Int, value: Any) -> Any:
1642
2119
  """Atomically add ``value`` onto ``arr[i]`` and return the old value."""
1643
2120
  ...
1644
2121
 
1645
2122
 
1646
2123
  @over
1647
- def atomic_add(arr: Array[Any], i: int32, j: int32, value: Any) -> Any:
2124
+ def atomic_add(arr: Array[Any], i: Int, j: Int, value: Any) -> Any:
1648
2125
  """Atomically add ``value`` onto ``arr[i,j]`` and return the old value."""
1649
2126
  ...
1650
2127
 
1651
2128
 
1652
2129
  @over
1653
- def atomic_add(arr: Array[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2130
+ def atomic_add(arr: Array[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1654
2131
  """Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value."""
1655
2132
  ...
1656
2133
 
1657
2134
 
1658
2135
  @over
1659
- def atomic_add(arr: Array[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2136
+ def atomic_add(arr: Array[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1660
2137
  """Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value."""
1661
2138
  ...
1662
2139
 
1663
2140
 
1664
2141
  @over
1665
- def atomic_add(arr: FabricArray[Any], i: int32, value: Any) -> Any:
2142
+ def atomic_add(arr: FabricArray[Any], i: Int, value: Any) -> Any:
1666
2143
  """Atomically add ``value`` onto ``arr[i]`` and return the old value."""
1667
2144
  ...
1668
2145
 
1669
2146
 
1670
2147
  @over
1671
- def atomic_add(arr: FabricArray[Any], i: int32, j: int32, value: Any) -> Any:
2148
+ def atomic_add(arr: FabricArray[Any], i: Int, j: Int, value: Any) -> Any:
1672
2149
  """Atomically add ``value`` onto ``arr[i,j]`` and return the old value."""
1673
2150
  ...
1674
2151
 
1675
2152
 
1676
2153
  @over
1677
- def atomic_add(arr: FabricArray[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2154
+ def atomic_add(arr: FabricArray[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1678
2155
  """Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value."""
1679
2156
  ...
1680
2157
 
1681
2158
 
1682
2159
  @over
1683
- def atomic_add(arr: FabricArray[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2160
+ def atomic_add(arr: FabricArray[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1684
2161
  """Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value."""
1685
2162
  ...
1686
2163
 
1687
2164
 
1688
2165
  @over
1689
- def atomic_add(arr: IndexedFabricArray[Any], i: int32, value: Any) -> Any:
2166
+ def atomic_add(arr: IndexedFabricArray[Any], i: Int, value: Any) -> Any:
1690
2167
  """Atomically add ``value`` onto ``arr[i]`` and return the old value."""
1691
2168
  ...
1692
2169
 
1693
2170
 
1694
2171
  @over
1695
- def atomic_add(arr: IndexedFabricArray[Any], i: int32, j: int32, value: Any) -> Any:
2172
+ def atomic_add(arr: IndexedFabricArray[Any], i: Int, j: Int, value: Any) -> Any:
1696
2173
  """Atomically add ``value`` onto ``arr[i,j]`` and return the old value."""
1697
2174
  ...
1698
2175
 
1699
2176
 
1700
2177
  @over
1701
- def atomic_add(arr: IndexedFabricArray[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2178
+ def atomic_add(arr: IndexedFabricArray[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1702
2179
  """Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value."""
1703
2180
  ...
1704
2181
 
1705
2182
 
1706
2183
  @over
1707
- def atomic_add(arr: IndexedFabricArray[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2184
+ def atomic_add(arr: IndexedFabricArray[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1708
2185
  """Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value."""
1709
2186
  ...
1710
2187
 
1711
2188
 
1712
2189
  @over
1713
- def atomic_sub(arr: Array[Any], i: int32, value: Any) -> Any:
2190
+ def atomic_sub(arr: Array[Any], i: Int, value: Any) -> Any:
1714
2191
  """Atomically subtract ``value`` onto ``arr[i]`` and return the old value."""
1715
2192
  ...
1716
2193
 
1717
2194
 
1718
2195
  @over
1719
- def atomic_sub(arr: Array[Any], i: int32, j: int32, value: Any) -> Any:
2196
+ def atomic_sub(arr: Array[Any], i: Int, j: Int, value: Any) -> Any:
1720
2197
  """Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value."""
1721
2198
  ...
1722
2199
 
1723
2200
 
1724
2201
  @over
1725
- def atomic_sub(arr: Array[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2202
+ def atomic_sub(arr: Array[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1726
2203
  """Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value."""
1727
2204
  ...
1728
2205
 
1729
2206
 
1730
2207
  @over
1731
- def atomic_sub(arr: Array[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2208
+ def atomic_sub(arr: Array[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1732
2209
  """Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value."""
1733
2210
  ...
1734
2211
 
1735
2212
 
1736
2213
  @over
1737
- def atomic_sub(arr: FabricArray[Any], i: int32, value: Any) -> Any:
2214
+ def atomic_sub(arr: FabricArray[Any], i: Int, value: Any) -> Any:
1738
2215
  """Atomically subtract ``value`` onto ``arr[i]`` and return the old value."""
1739
2216
  ...
1740
2217
 
1741
2218
 
1742
2219
  @over
1743
- def atomic_sub(arr: FabricArray[Any], i: int32, j: int32, value: Any) -> Any:
2220
+ def atomic_sub(arr: FabricArray[Any], i: Int, j: Int, value: Any) -> Any:
1744
2221
  """Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value."""
1745
2222
  ...
1746
2223
 
1747
2224
 
1748
2225
  @over
1749
- def atomic_sub(arr: FabricArray[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2226
+ def atomic_sub(arr: FabricArray[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1750
2227
  """Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value."""
1751
2228
  ...
1752
2229
 
1753
2230
 
1754
2231
  @over
1755
- def atomic_sub(arr: FabricArray[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2232
+ def atomic_sub(arr: FabricArray[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1756
2233
  """Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value."""
1757
2234
  ...
1758
2235
 
1759
2236
 
1760
2237
  @over
1761
- def atomic_sub(arr: IndexedFabricArray[Any], i: int32, value: Any) -> Any:
2238
+ def atomic_sub(arr: IndexedFabricArray[Any], i: Int, value: Any) -> Any:
1762
2239
  """Atomically subtract ``value`` onto ``arr[i]`` and return the old value."""
1763
2240
  ...
1764
2241
 
1765
2242
 
1766
2243
  @over
1767
- def atomic_sub(arr: IndexedFabricArray[Any], i: int32, j: int32, value: Any) -> Any:
2244
+ def atomic_sub(arr: IndexedFabricArray[Any], i: Int, j: Int, value: Any) -> Any:
1768
2245
  """Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value."""
1769
2246
  ...
1770
2247
 
1771
2248
 
1772
2249
  @over
1773
- def atomic_sub(arr: IndexedFabricArray[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2250
+ def atomic_sub(arr: IndexedFabricArray[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1774
2251
  """Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value."""
1775
2252
  ...
1776
2253
 
1777
2254
 
1778
2255
  @over
1779
- def atomic_sub(arr: IndexedFabricArray[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2256
+ def atomic_sub(arr: IndexedFabricArray[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1780
2257
  """Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value."""
1781
2258
  ...
1782
2259
 
1783
2260
 
1784
2261
  @over
1785
- def atomic_min(arr: Array[Any], i: int32, value: Any) -> Any:
2262
+ def atomic_min(arr: Array[Any], i: Int, value: Any) -> Any:
1786
2263
  """Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
1787
2264
 
1788
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2265
+ The operation is only atomic on a per-component basis for vectors and matrices.
1789
2266
  """
1790
2267
  ...
1791
2268
 
1792
2269
 
1793
2270
  @over
1794
- def atomic_min(arr: Array[Any], i: int32, j: int32, value: Any) -> Any:
2271
+ def atomic_min(arr: Array[Any], i: Int, j: Int, value: Any) -> Any:
1795
2272
  """Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
1796
2273
 
1797
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2274
+ The operation is only atomic on a per-component basis for vectors and matrices.
1798
2275
  """
1799
2276
  ...
1800
2277
 
1801
2278
 
1802
2279
  @over
1803
- def atomic_min(arr: Array[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2280
+ def atomic_min(arr: Array[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1804
2281
  """Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
1805
2282
 
1806
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2283
+ The operation is only atomic on a per-component basis for vectors and matrices.
1807
2284
  """
1808
2285
  ...
1809
2286
 
1810
2287
 
1811
2288
  @over
1812
- def atomic_min(arr: Array[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2289
+ def atomic_min(arr: Array[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1813
2290
  """Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
1814
2291
 
1815
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2292
+ The operation is only atomic on a per-component basis for vectors and matrices.
1816
2293
  """
1817
2294
  ...
1818
2295
 
1819
2296
 
1820
2297
  @over
1821
- def atomic_min(arr: FabricArray[Any], i: int32, value: Any) -> Any:
2298
+ def atomic_min(arr: FabricArray[Any], i: Int, value: Any) -> Any:
1822
2299
  """Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
1823
2300
 
1824
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2301
+ The operation is only atomic on a per-component basis for vectors and matrices.
1825
2302
  """
1826
2303
  ...
1827
2304
 
1828
2305
 
1829
2306
  @over
1830
- def atomic_min(arr: FabricArray[Any], i: int32, j: int32, value: Any) -> Any:
2307
+ def atomic_min(arr: FabricArray[Any], i: Int, j: Int, value: Any) -> Any:
1831
2308
  """Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
1832
2309
 
1833
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2310
+ The operation is only atomic on a per-component basis for vectors and matrices.
1834
2311
  """
1835
2312
  ...
1836
2313
 
1837
2314
 
1838
2315
  @over
1839
- def atomic_min(arr: FabricArray[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2316
+ def atomic_min(arr: FabricArray[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1840
2317
  """Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
1841
2318
 
1842
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2319
+ The operation is only atomic on a per-component basis for vectors and matrices.
1843
2320
  """
1844
2321
  ...
1845
2322
 
1846
2323
 
1847
2324
  @over
1848
- def atomic_min(arr: FabricArray[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2325
+ def atomic_min(arr: FabricArray[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1849
2326
  """Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
1850
2327
 
1851
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2328
+ The operation is only atomic on a per-component basis for vectors and matrices.
1852
2329
  """
1853
2330
  ...
1854
2331
 
1855
2332
 
1856
2333
  @over
1857
- def atomic_min(arr: IndexedFabricArray[Any], i: int32, value: Any) -> Any:
2334
+ def atomic_min(arr: IndexedFabricArray[Any], i: Int, value: Any) -> Any:
1858
2335
  """Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
1859
2336
 
1860
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2337
+ The operation is only atomic on a per-component basis for vectors and matrices.
1861
2338
  """
1862
2339
  ...
1863
2340
 
1864
2341
 
1865
2342
  @over
1866
- def atomic_min(arr: IndexedFabricArray[Any], i: int32, j: int32, value: Any) -> Any:
2343
+ def atomic_min(arr: IndexedFabricArray[Any], i: Int, j: Int, value: Any) -> Any:
1867
2344
  """Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
1868
2345
 
1869
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2346
+ The operation is only atomic on a per-component basis for vectors and matrices.
1870
2347
  """
1871
2348
  ...
1872
2349
 
1873
2350
 
1874
2351
  @over
1875
- def atomic_min(arr: IndexedFabricArray[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2352
+ def atomic_min(arr: IndexedFabricArray[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1876
2353
  """Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
1877
2354
 
1878
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2355
+ The operation is only atomic on a per-component basis for vectors and matrices.
1879
2356
  """
1880
2357
  ...
1881
2358
 
1882
2359
 
1883
2360
  @over
1884
- def atomic_min(arr: IndexedFabricArray[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2361
+ def atomic_min(arr: IndexedFabricArray[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1885
2362
  """Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
1886
2363
 
1887
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2364
+ The operation is only atomic on a per-component basis for vectors and matrices.
1888
2365
  """
1889
2366
  ...
1890
2367
 
1891
2368
 
1892
2369
  @over
1893
- def atomic_max(arr: Array[Any], i: int32, value: Any) -> Any:
2370
+ def atomic_max(arr: Array[Any], i: Int, value: Any) -> Any:
1894
2371
  """Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
1895
2372
 
1896
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2373
+ The operation is only atomic on a per-component basis for vectors and matrices.
1897
2374
  """
1898
2375
  ...
1899
2376
 
1900
2377
 
1901
2378
  @over
1902
- def atomic_max(arr: Array[Any], i: int32, j: int32, value: Any) -> Any:
2379
+ def atomic_max(arr: Array[Any], i: Int, j: Int, value: Any) -> Any:
1903
2380
  """Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
1904
2381
 
1905
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2382
+ The operation is only atomic on a per-component basis for vectors and matrices.
1906
2383
  """
1907
2384
  ...
1908
2385
 
1909
2386
 
1910
2387
  @over
1911
- def atomic_max(arr: Array[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2388
+ def atomic_max(arr: Array[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1912
2389
  """Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
1913
2390
 
1914
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2391
+ The operation is only atomic on a per-component basis for vectors and matrices.
1915
2392
  """
1916
2393
  ...
1917
2394
 
1918
2395
 
1919
2396
  @over
1920
- def atomic_max(arr: Array[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2397
+ def atomic_max(arr: Array[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1921
2398
  """Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
1922
2399
 
1923
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2400
+ The operation is only atomic on a per-component basis for vectors and matrices.
1924
2401
  """
1925
2402
  ...
1926
2403
 
1927
2404
 
1928
2405
  @over
1929
- def atomic_max(arr: FabricArray[Any], i: int32, value: Any) -> Any:
2406
+ def atomic_max(arr: FabricArray[Any], i: Int, value: Any) -> Any:
1930
2407
  """Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
1931
2408
 
1932
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2409
+ The operation is only atomic on a per-component basis for vectors and matrices.
1933
2410
  """
1934
2411
  ...
1935
2412
 
1936
2413
 
1937
2414
  @over
1938
- def atomic_max(arr: FabricArray[Any], i: int32, j: int32, value: Any) -> Any:
2415
+ def atomic_max(arr: FabricArray[Any], i: Int, j: Int, value: Any) -> Any:
1939
2416
  """Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
1940
2417
 
1941
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2418
+ The operation is only atomic on a per-component basis for vectors and matrices.
1942
2419
  """
1943
2420
  ...
1944
2421
 
1945
2422
 
1946
2423
  @over
1947
- def atomic_max(arr: FabricArray[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2424
+ def atomic_max(arr: FabricArray[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1948
2425
  """Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
1949
2426
 
1950
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2427
+ The operation is only atomic on a per-component basis for vectors and matrices.
1951
2428
  """
1952
2429
  ...
1953
2430
 
1954
2431
 
1955
2432
  @over
1956
- def atomic_max(arr: FabricArray[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2433
+ def atomic_max(arr: FabricArray[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1957
2434
  """Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
1958
2435
 
1959
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2436
+ The operation is only atomic on a per-component basis for vectors and matrices.
1960
2437
  """
1961
2438
  ...
1962
2439
 
1963
2440
 
1964
2441
  @over
1965
- def atomic_max(arr: IndexedFabricArray[Any], i: int32, value: Any) -> Any:
2442
+ def atomic_max(arr: IndexedFabricArray[Any], i: Int, value: Any) -> Any:
1966
2443
  """Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
1967
2444
 
1968
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2445
+ The operation is only atomic on a per-component basis for vectors and matrices.
1969
2446
  """
1970
2447
  ...
1971
2448
 
1972
2449
 
1973
2450
  @over
1974
- def atomic_max(arr: IndexedFabricArray[Any], i: int32, j: int32, value: Any) -> Any:
2451
+ def atomic_max(arr: IndexedFabricArray[Any], i: Int, j: Int, value: Any) -> Any:
1975
2452
  """Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
1976
2453
 
1977
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2454
+ The operation is only atomic on a per-component basis for vectors and matrices.
1978
2455
  """
1979
2456
  ...
1980
2457
 
1981
2458
 
1982
2459
  @over
1983
- def atomic_max(arr: IndexedFabricArray[Any], i: int32, j: int32, k: int32, value: Any) -> Any:
2460
+ def atomic_max(arr: IndexedFabricArray[Any], i: Int, j: Int, k: Int, value: Any) -> Any:
1984
2461
  """Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
1985
2462
 
1986
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2463
+ The operation is only atomic on a per-component basis for vectors and matrices.
1987
2464
  """
1988
2465
  ...
1989
2466
 
1990
2467
 
1991
2468
  @over
1992
- def atomic_max(arr: IndexedFabricArray[Any], i: int32, j: int32, k: int32, l: int32, value: Any) -> Any:
2469
+ def atomic_max(arr: IndexedFabricArray[Any], i: Int, j: Int, k: Int, l: Int, value: Any) -> Any:
1993
2470
  """Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
1994
2471
 
1995
- .. note:: The operation is only atomic on a per-component basis for vectors and matrices.
2472
+ The operation is only atomic on a per-component basis for vectors and matrices.
1996
2473
  """
1997
2474
  ...
1998
2475
 
@@ -2089,6 +2566,12 @@ def add(a: Transformation[Scalar], b: Transformation[Scalar]) -> Transformation[
2089
2566
  ...
2090
2567
 
2091
2568
 
2569
+ @over
2570
+ def add(a: Tile, b: Tile) -> Tile:
2571
+ """Add each element of two tiles together"""
2572
+ ...
2573
+
2574
+
2092
2575
  @over
2093
2576
  def sub(a: Scalar, b: Scalar) -> Scalar:
2094
2577
  """ """
@@ -2239,6 +2722,18 @@ def mul(a: Transformation[Scalar], b: Scalar) -> Transformation[Scalar]:
2239
2722
  ...
2240
2723
 
2241
2724
 
2725
+ @over
2726
+ def mul(x: Tile, y: Scalar) -> Tile:
2727
+ """Multiply each element of a tile by a scalar"""
2728
+ ...
2729
+
2730
+
2731
+ @over
2732
+ def mul(x: Scalar, y: Tile) -> Tile:
2733
+ """Multiply each element of a tile by a scalar"""
2734
+ ...
2735
+
2736
+
2242
2737
  @over
2243
2738
  def mod(a: Scalar, b: Scalar) -> Scalar:
2244
2739
  """Modulo operation using truncated division."""
@@ -2347,6 +2842,12 @@ def neg(x: Matrix[Any, Any, Scalar]) -> Matrix[Any, Any, Scalar]:
2347
2842
  ...
2348
2843
 
2349
2844
 
2845
+ @over
2846
+ def neg(x: Tile) -> Tile:
2847
+ """Negate each element of a tile"""
2848
+ ...
2849
+
2850
+
2350
2851
  @over
2351
2852
  def unot(a: bool) -> bool:
2352
2853
  """ """
@@ -2407,16 +2908,81 @@ def unot(a: Array[Any]) -> bool:
2407
2908
  ...
2408
2909
 
2409
2910
 
2911
+ @over
2912
+ def tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile:
2913
+ """Computes the matrix product and accumulates ``out += a*b``.
2914
+
2915
+ Supported datatypes are:
2916
+ * fp16, fp32, fp64 (real)
2917
+ * vec2h, vec2f, vec2d (complex)
2918
+
2919
+ All input and output tiles must have the same datatype. Tile data will automatically be migrated
2920
+ to shared memory if necessary and will use TensorCore operations when available.
2921
+
2922
+ :param a: A tile with ``shape=(M, K)``
2923
+ :param b: A tile with ``shape=(K, N)``
2924
+ :param out: A tile with ``shape=(M, N)``
2925
+
2926
+ """
2927
+ ...
2928
+
2929
+
2930
+ @over
2931
+ def tile_matmul(a: Tile, b: Tile) -> Tile:
2932
+ """Computes the matrix product ``out = a*b``.
2933
+
2934
+ Supported datatypes are:
2935
+ * fp16, fp32, fp64 (real)
2936
+ * vec2h, vec2f, vec2d (complex)
2937
+
2938
+ Both input tiles must have the same datatype. Tile data will automatically be migrated
2939
+ to shared memory if necessary and will use TensorCore operations when available.
2940
+
2941
+ :param a: A tile with ``shape=(M, K)``
2942
+ :param b: A tile with ``shape=(K, N)``
2943
+ :returns: A tile with ``shape=(M, N)``
2944
+
2945
+ """
2946
+ ...
2947
+
2948
+
2949
+ @over
2950
+ def tile_fft(inout: Tile) -> Tile:
2951
+ """Compute the forward FFT along the second dimension of a 2D tile of data.
2952
+
2953
+ This function cooperatively computes the forward FFT on a tile of data in place, treating each row individually.
2954
+
2955
+ Supported datatypes are:
2956
+ * vec2f, vec2d
2957
+
2958
+ :param inout: The input/output tile
2959
+ """
2960
+ ...
2961
+
2962
+
2963
+ @over
2964
+ def tile_ifft(inout: Tile) -> Tile:
2965
+ """Compute the inverse FFT along the second dimension of a 2D tile of data.
2966
+
2967
+ This function cooperatively computes the inverse FFT on a tile of data in place, treating each row individually.
2968
+
2969
+ Supported datatypes are:
2970
+ * vec2f, vec2d
2971
+
2972
+ :param inout: The input/output tile
2973
+ """
2974
+ ...
2975
+
2976
+
2410
2977
  @over
2411
2978
  def static(expr: Any) -> Any:
2412
2979
  """Evaluates a static Python expression and replaces it with its result.
2413
2980
 
2414
- See the `codegen.html#static-expressions <section on code generation>`_ for more details.
2981
+ See the :ref:`code generation guide <static_expressions>` for more details.
2415
2982
 
2416
- Note:
2417
- The inner expression must only reference variables that are available from the current scope where the Warp kernel or function containing the expression is defined,
2418
- which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
2419
- The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
2420
- (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).
2983
+ The inner expression must only reference variables that are available from the current scope where the Warp kernel or function containing the expression is defined,
2984
+ which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
2985
+ The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
2986
+ (excluding Warp arrays since they cannot be created in a Warp kernel at the moment).
2421
2987
  """
2422
2988
  ...