warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang has been flagged as potentially problematic; consult the package registry's advisory page for details.

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -18,20 +18,19 @@ from __future__ import annotations
18
18
  import builtins
19
19
  import ctypes
20
20
  import inspect
21
+ import math
21
22
  import struct
22
23
  import zlib
23
24
  from typing import (
24
25
  Any,
25
26
  Callable,
27
+ ClassVar,
26
28
  Generic,
27
- List,
28
29
  Literal,
29
30
  NamedTuple,
30
- Optional,
31
31
  Sequence,
32
32
  Tuple,
33
33
  TypeVar,
34
- Union,
35
34
  get_args,
36
35
  get_origin,
37
36
  )
@@ -47,6 +46,7 @@ Length = TypeVar("Length", bound=int)
47
46
  Rows = TypeVar("Rows")
48
47
  Cols = TypeVar("Cols")
49
48
  DType = TypeVar("DType")
49
+ Shape = TypeVar("Shape", bound=Tuple[int, ...])
50
50
 
51
51
  Int = TypeVar("Int")
52
52
  Float = TypeVar("Float")
@@ -70,18 +70,96 @@ class Transformation(Generic[Float]):
70
70
 
71
71
 
72
72
  class Array(Generic[DType]):
73
- device: Optional[warp.context.Device]
73
+ device: warp.context.Device | None
74
74
  dtype: type
75
75
  size: int
76
76
 
77
+ def __add__(self, other) -> array:
78
+ return warp.map(warp.add, self, other) # type: ignore
77
79
 
78
- int_tuple_type_hints = {
79
- Tuple[int]: 1,
80
- Tuple[int, int]: 2,
81
- Tuple[int, int, int]: 3,
82
- Tuple[int, int, int, int]: 4,
83
- Tuple[int, ...]: -1,
84
- }
80
+ def __radd__(self, other) -> array:
81
+ return warp.map(warp.add, other, self) # type: ignore
82
+
83
+ def __sub__(self, other) -> array:
84
+ return warp.map(warp.sub, self, other) # type: ignore
85
+
86
+ def __rsub__(self, other) -> array:
87
+ return warp.map(warp.sub, other, self) # type: ignore
88
+
89
+ def __mul__(self, other) -> array:
90
+ return warp.map(warp.mul, self, other) # type: ignore
91
+
92
+ def __rmul__(self, other) -> array:
93
+ return warp.map(warp.mul, other, self) # type: ignore
94
+
95
+ def __truediv__(self, other) -> array:
96
+ return warp.map(warp.div, self, other) # type: ignore
97
+
98
+ def __rtruediv__(self, other) -> array:
99
+ return warp.map(warp.div, other, self) # type: ignore
100
+
101
+ def __floordiv__(self, other) -> array:
102
+ return warp.map(warp.floordiv, self, other) # type: ignore
103
+
104
+ def __rfloordiv__(self, other) -> array:
105
+ return warp.map(warp.floordiv, other, self) # type: ignore
106
+
107
+ def __mod__(self, other) -> array:
108
+ return warp.map(warp.mod, self, other) # type: ignore
109
+
110
+ def __rmod__(self, other) -> array:
111
+ return warp.map(warp.mod, other, self) # type: ignore
112
+
113
+ def __pow__(self, other) -> array:
114
+ return warp.map(warp.pow, self, other) # type: ignore
115
+
116
+ def __rpow__(self, other) -> array:
117
+ return warp.map(warp.pow, other, self) # type: ignore
118
+
119
+ def __neg__(self) -> array:
120
+ return warp.map(warp.neg, self) # type: ignore
121
+
122
+ def __pos__(self) -> array:
123
+ return warp.map(warp.pos, self) # type: ignore
124
+
125
+ def __iadd__(self, other):
126
+ """In-place addition operator."""
127
+ warp.map(warp.add, self, other, out=self)
128
+ return self
129
+
130
+ def __isub__(self, other):
131
+ """In-place subtraction operator."""
132
+ warp.map(warp.sub, self, other, out=self)
133
+ return self
134
+
135
+ def __imul__(self, other):
136
+ """In-place multiplication operator."""
137
+ warp.map(warp.mul, self, other, out=self)
138
+ return self
139
+
140
+ def __itruediv__(self, other):
141
+ """In-place division operator."""
142
+ warp.map(warp.div, self, other, out=self)
143
+ return self
144
+
145
+ def __ifloordiv__(self, other):
146
+ """In-place floor division operator."""
147
+ warp.map(warp.floordiv, self, other, out=self)
148
+ return self
149
+
150
+ def __imod__(self, other):
151
+ """In-place modulo operator."""
152
+ warp.map(warp.mod, self, other, out=self)
153
+ return self
154
+
155
+ def __ipow__(self, other):
156
+ """In-place power operator."""
157
+ warp.map(warp.pow, self, other, out=self)
158
+ return self
159
+
160
+
161
+ class Tile(Generic[DType, Shape]):
162
+ pass
85
163
 
86
164
 
87
165
  def constant(x):
@@ -105,6 +183,13 @@ def half_bits_to_float(value):
105
183
  return warp.context.runtime.core.half_bits_to_float(value)
106
184
 
107
185
 
186
+ def safe_len(obj):
187
+ try:
188
+ return len(obj)
189
+ except TypeError:
190
+ return -1
191
+
192
+
108
193
  # ----------------------
109
194
  # built-in types
110
195
 
@@ -134,8 +219,8 @@ def vector(length, dtype):
134
219
 
135
220
  # warp scalar type:
136
221
  _wp_scalar_type_ = dtype
137
- _wp_type_params_ = [length, dtype]
138
- _wp_type_args_ = {"length": length, "dtype": dtype}
222
+ _wp_type_params_ = (length, dtype)
223
+ _wp_type_args_: ClassVar[dict[str, Any]] = {"length": length, "dtype": dtype}
139
224
  _wp_generic_type_str_ = "vec_t"
140
225
  _wp_generic_type_hint_ = Vector
141
226
  _wp_constructor_ = "vector"
@@ -282,7 +367,7 @@ def vector(length, dtype):
282
367
  return f"{type_repr(self)}([{', '.join(map(repr, self))}])"
283
368
 
284
369
  def __eq__(self, other):
285
- if self._length_ != len(other):
370
+ if self._length_ != safe_len(other):
286
371
  return False
287
372
  for i in range(self._length_):
288
373
  if self[i] != other[i]:
@@ -330,8 +415,8 @@ def matrix(shape, dtype):
330
415
  # warp scalar type:
331
416
  # used in type checking and when writing out c++ code for constructors:
332
417
  _wp_scalar_type_ = dtype
333
- _wp_type_params_ = [shape[0], shape[1], dtype]
334
- _wp_type_args_ = {"shape": (shape[0], shape[1]), "dtype": dtype}
418
+ _wp_type_params_ = (shape[0], shape[1], dtype)
419
+ _wp_type_args_: ClassVar[dict[str, Any]] = {"shape": (shape[0], shape[1]), "dtype": dtype}
335
420
  _wp_generic_type_str_ = "mat_t"
336
421
  _wp_generic_type_hint_ = Matrix
337
422
  _wp_constructor_ = "matrix"
@@ -426,10 +511,10 @@ def matrix(shape, dtype):
426
511
  return "[" + ",\n ".join(row_str) + "]"
427
512
 
428
513
  def __eq__(self, other):
429
- if self._shape_[0] != len(other):
514
+ if self._shape_[0] != safe_len(other):
430
515
  return False
431
516
  for i in range(self._shape_[0]):
432
- if self._shape_[1] != len(other[i]):
517
+ if self._shape_[1] != safe_len(other[i]):
433
518
  return False
434
519
  for j in range(self._shape_[1]):
435
520
  if self[i][j] != other[i][j]:
@@ -703,15 +788,15 @@ def transformation(dtype=Any):
703
788
  ),
704
789
  ),
705
790
  )
706
- _wp_type_params_ = [dtype]
707
- _wp_type_args_ = {"dtype": dtype}
791
+ _wp_type_params_ = (dtype,)
792
+ _wp_type_args_: ClassVar[dict[str, type]] = {"dtype": dtype}
708
793
  _wp_generic_type_str_ = "transform_t"
709
794
  _wp_generic_type_hint_ = Transformation
710
795
  _wp_constructor_ = "transformation"
711
796
 
712
797
  def __init__(self, *args, **kwargs):
713
798
  if len(args) == 1 and len(kwargs) == 0:
714
- if is_float(args[0]):
799
+ if is_float(args[0]) or is_int(args[0]):
715
800
  # Initialize from a single scalar.
716
801
  super().__init__(args[0])
717
802
  return
@@ -745,13 +830,26 @@ def transformation(dtype=Any):
745
830
  # Fallback to the vector's constructor.
746
831
  super().__init__(*args)
747
832
 
748
- @property
749
- def p(self):
750
- return vec3(self[0:3])
833
+ def __getattr__(self, name):
834
+ if name == "p":
835
+ return vec3(self[0:3])
836
+ elif name == "q":
837
+ return quat(self[3:7])
838
+ else:
839
+ return self.__getattribute__(name)
751
840
 
752
- @property
753
- def q(self):
754
- return quat(self[3:7])
841
+ def __setattr__(self, name, value):
842
+ if name == "p":
843
+ self[0:3] = vector(length=3, dtype=dtype)(*value)
844
+ elif name == "q":
845
+ self[3:7] = quaternion(dtype=dtype)(*value)
846
+ else:
847
+ # we don't permit vector xyzw indexing for transformations
848
+ idx = "xyzw".find(name)
849
+ if idx != -1:
850
+ raise RuntimeError(f"Cannot set attribute {name} of transformation")
851
+ else:
852
+ super().__setattr__(name, value)
755
853
 
756
854
  return transform_t
757
855
 
@@ -976,7 +1074,7 @@ spatial_matrix = spatial_matrixf
976
1074
  int_types = (int8, uint8, int16, uint16, int32, uint32, int64, uint64)
977
1075
  float_types = (float16, float32, float64)
978
1076
  scalar_types = int_types + float_types
979
- scalar_and_bool_types = scalar_types + (bool,)
1077
+ scalar_and_bool_types = (*scalar_types, bool)
980
1078
 
981
1079
  vector_types = (
982
1080
  vec2b,
@@ -1150,36 +1248,24 @@ class range_t:
1150
1248
 
1151
1249
 
1152
1250
  # definition just for kernel type (cannot be a parameter), see bvh.h
1153
- class bvh_query_t:
1251
+ class BvhQuery:
1154
1252
  """Object used to track state during BVH traversal."""
1155
1253
 
1156
- def __init__(self):
1157
- pass
1158
-
1159
-
1160
- BvhQuery = bvh_query_t
1254
+ _wp_native_name_ = "bvh_query_t"
1161
1255
 
1162
1256
 
1163
1257
  # definition just for kernel type (cannot be a parameter), see mesh.h
1164
- class mesh_query_aabb_t:
1258
+ class MeshQueryAABB:
1165
1259
  """Object used to track state during mesh traversal."""
1166
1260
 
1167
- def __init__(self):
1168
- pass
1169
-
1170
-
1171
- MeshQueryAABB = mesh_query_aabb_t
1261
+ _wp_native_name_ = "mesh_query_aabb_t"
1172
1262
 
1173
1263
 
1174
1264
  # definition just for kernel type (cannot be a parameter), see hash_grid.h
1175
- class hash_grid_query_t:
1265
+ class HashGridQuery:
1176
1266
  """Object used to track state during neighbor traversal."""
1177
1267
 
1178
- def __init__(self):
1179
- pass
1180
-
1181
-
1182
- HashGridQuery = hash_grid_query_t
1268
+ _wp_native_name_ = "hash_grid_query_t"
1183
1269
 
1184
1270
 
1185
1271
  # maximum number of dimensions, must match array.h
@@ -1195,9 +1281,13 @@ ARRAY_TYPE_FABRIC_INDEXED = 3
1195
1281
 
1196
1282
  # represents bounds for kernel launch (number of threads across multiple dimensions)
1197
1283
  class launch_bounds_t(ctypes.Structure):
1198
- _fields_ = [("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS), ("ndim", ctypes.c_int32), ("size", ctypes.c_size_t)]
1284
+ _fields_ = (
1285
+ ("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS),
1286
+ ("ndim", ctypes.c_int32),
1287
+ ("size", ctypes.c_size_t),
1288
+ )
1199
1289
 
1200
- def __init__(self, shape: Union[int, Sequence[int]]):
1290
+ def __init__(self, shape: int | Sequence[int]):
1201
1291
  if isinstance(shape, int):
1202
1292
  # 1d launch
1203
1293
  self.ndim = 1
@@ -1219,20 +1309,20 @@ class launch_bounds_t(ctypes.Structure):
1219
1309
 
1220
1310
 
1221
1311
  class shape_t(ctypes.Structure):
1222
- _fields_ = [("dims", ctypes.c_int32 * ARRAY_MAX_DIMS)]
1312
+ _fields_ = (("dims", ctypes.c_int32 * ARRAY_MAX_DIMS),)
1223
1313
 
1224
1314
  def __init__(self):
1225
1315
  pass
1226
1316
 
1227
1317
 
1228
1318
  class array_t(ctypes.Structure):
1229
- _fields_ = [
1319
+ _fields_ = (
1230
1320
  ("data", ctypes.c_uint64),
1231
1321
  ("grad", ctypes.c_uint64),
1232
1322
  ("shape", ctypes.c_int32 * ARRAY_MAX_DIMS),
1233
1323
  ("strides", ctypes.c_int32 * ARRAY_MAX_DIMS),
1234
1324
  ("ndim", ctypes.c_int32),
1235
- ]
1325
+ )
1236
1326
 
1237
1327
  def __init__(self, data=0, grad=0, ndim=0, shape=(0,), strides=(0,)):
1238
1328
  self.data = data
@@ -1268,11 +1358,11 @@ array_t._numpy_dtype_ = {
1268
1358
 
1269
1359
 
1270
1360
  class indexedarray_t(ctypes.Structure):
1271
- _fields_ = [
1361
+ _fields_ = (
1272
1362
  ("data", array_t),
1273
1363
  ("indices", ctypes.c_void_p * ARRAY_MAX_DIMS),
1274
1364
  ("shape", ctypes.c_int32 * ARRAY_MAX_DIMS),
1275
- ]
1365
+ )
1276
1366
 
1277
1367
  def __init__(self, data, indices, shape):
1278
1368
  if data is None:
@@ -1290,17 +1380,44 @@ class indexedarray_t(ctypes.Structure):
1290
1380
  self.shape[i] = shape[i]
1291
1381
 
1292
1382
 
1383
+ class tuple_t:
1384
+ """Used during codegen to store multiple values into a single variable."""
1385
+
1386
+ def __init__(self, types, values):
1387
+ self.types = types
1388
+ self.values = values
1389
+
1390
+
1293
1391
  def type_ctype(dtype):
1294
1392
  if dtype == float:
1295
1393
  return ctypes.c_float
1296
1394
  elif dtype == int:
1297
1395
  return ctypes.c_int32
1396
+ elif dtype == bool:
1397
+ return ctypes.c_bool
1398
+ elif issubclass(dtype, (ctypes.Array, ctypes.Structure)):
1399
+ return dtype
1298
1400
  else:
1299
1401
  # scalar type
1300
1402
  return dtype._type_
1301
1403
 
1302
1404
 
1303
- def type_length(dtype):
1405
+ def type_length(obj):
1406
+ if is_tile(obj):
1407
+ return obj.shape[0]
1408
+ elif is_tuple(obj):
1409
+ return len(obj.types)
1410
+ elif get_origin(obj) is tuple:
1411
+ return len(get_args(obj))
1412
+ elif hasattr(obj, "_shape_"):
1413
+ return obj._shape_[0]
1414
+ elif hasattr(obj, "_length_"):
1415
+ return obj._length_
1416
+
1417
+ return len(obj)
1418
+
1419
+
1420
+ def type_size(dtype):
1304
1421
  if dtype == float or dtype == int or isinstance(dtype, warp.codegen.Struct):
1305
1422
  return 1
1306
1423
  else:
@@ -1406,12 +1523,14 @@ def scalar_short_name(t):
1406
1523
 
1407
1524
 
1408
1525
  # converts any known type to a human readable string, good for error messages, reporting etc
1409
- def type_repr(t):
1526
+ def type_repr(t) -> str:
1410
1527
  if is_array(t):
1411
1528
  if t.device is None:
1412
1529
  # array is used as a type annotation - display ndim instead of shape
1413
1530
  return f"array(ndim={t.ndim}, dtype={type_repr(t.dtype)})"
1414
1531
  return f"array(shape={t.shape}, dtype={type_repr(t.dtype)})"
1532
+ if is_tuple(t):
1533
+ return f"tuple({', '.join(type_repr(x) for x in t.types)})"
1415
1534
  if is_tile(t):
1416
1535
  return f"tile(shape={t.shape}, dtype={type_repr(t.dtype)})"
1417
1536
  if isinstance(t, warp.codegen.Struct):
@@ -1437,6 +1556,12 @@ def type_repr(t):
1437
1556
  return f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={type_repr(t._wp_scalar_type_)})"
1438
1557
  if t in scalar_types:
1439
1558
  return t.__name__
1559
+ if t == builtins.bool:
1560
+ return "bool"
1561
+ if t == builtins.float:
1562
+ return "float"
1563
+ if t == builtins.int:
1564
+ return "int"
1440
1565
 
1441
1566
  name = getattr(t, "__name__", None)
1442
1567
  if name is None:
@@ -1479,7 +1604,7 @@ def type_is_transformation(t):
1479
1604
  return getattr(t, "_wp_generic_type_hint_", None) is Transformation
1480
1605
 
1481
1606
 
1482
- value_types = (int, float, builtins.bool) + scalar_and_bool_types
1607
+ value_types = (int, float, builtins.bool, *scalar_and_bool_types)
1483
1608
 
1484
1609
 
1485
1610
  # returns true for all value types (int, float, bool, scalars, vectors, matrices)
@@ -1505,6 +1630,10 @@ def is_array(a) -> builtins.bool:
1505
1630
  return isinstance(a, array_types)
1506
1631
 
1507
1632
 
1633
+ def is_tuple(t) -> builtins.bool:
1634
+ return isinstance(t, tuple_t)
1635
+
1636
+
1508
1637
  def scalars_equal(a, b, match_generic=False):
1509
1638
  # convert to canonical types
1510
1639
  if a == float:
@@ -1546,45 +1675,58 @@ def scalars_equal(a, b, match_generic=False):
1546
1675
  return a == b
1547
1676
 
1548
1677
 
1678
+ def seq_match_ellipsis(a, b, match_generic=False) -> bool:
1679
+ assert a and a[-1] is Ellipsis and len(a) == 2
1680
+
1681
+ # Compare the args against the type being repeated through the ellipsis.
1682
+ repeated_arg = a[0]
1683
+ if not all(types_equal(x, repeated_arg, match_generic=match_generic) for x in b):
1684
+ return False
1685
+
1686
+ return True
1687
+
1688
+
1549
1689
  def types_equal(a, b, match_generic=False):
1550
1690
  if match_generic:
1551
- # Special cases to interpret the types listed in `int_tuple_type_hints`
1552
- # as generic hints that accept any integer types.
1553
- if a in int_tuple_type_hints and isinstance(b, Sequence):
1554
- a_length = int_tuple_type_hints[a]
1555
- if (a_length == -1 or a_length == len(b)) and all(
1556
- scalars_equal(x, Int, match_generic=match_generic) for x in b
1557
- ):
1558
- return True
1559
- if b in int_tuple_type_hints and isinstance(a, Sequence):
1560
- b_length = int_tuple_type_hints[b]
1561
- if (b_length == -1 or b_length == len(a)) and all(
1562
- scalars_equal(x, Int, match_generic=match_generic) for x in a
1563
- ):
1564
- return True
1565
- if a in int_tuple_type_hints and b in int_tuple_type_hints:
1566
- a_length = int_tuple_type_hints[a]
1567
- b_length = int_tuple_type_hints[b]
1568
- if a_length is None or b_length is None or a_length == b_length:
1691
+ a_is_seq = True
1692
+ a_is_tuple = True
1693
+ if is_tuple(a):
1694
+ a = a.types
1695
+ elif get_origin(a) is tuple:
1696
+ a = get_args(a)
1697
+ else:
1698
+ a_is_tuple = False
1699
+ a_is_seq = isinstance(a, Sequence)
1700
+
1701
+ b_is_seq = True
1702
+ b_is_tuple = True
1703
+ if is_tuple(b):
1704
+ b = b.types
1705
+ elif get_origin(b) is tuple:
1706
+ b = get_args(b)
1707
+ else:
1708
+ b_is_tuple = False
1709
+ b_is_seq = isinstance(b, Sequence)
1710
+
1711
+ if a_is_seq and b_is_seq:
1712
+ if (not a and a_is_tuple) or (not b and b_is_tuple):
1713
+ # We have a bare tuple definition like `Tuple`, which matches to anything.
1569
1714
  return True
1570
1715
 
1571
- a_origin = get_origin(a)
1572
- b_origin = get_origin(b)
1573
- if a_origin is tuple and b_origin is tuple:
1574
- a_args = get_args(a)
1575
- b_args = get_args(b)
1576
- if len(a_args) == len(b_args) and all(
1577
- scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
1578
- ):
1579
- return True
1580
- elif a_origin is tuple and isinstance(b, Sequence):
1581
- a_args = get_args(a)
1582
- if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
1583
- return True
1584
- elif b_origin is tuple and isinstance(a, Sequence):
1585
- b_args = get_args(b)
1586
- if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
1587
- return True
1716
+ a_has_ellipsis = a and a[-1] is Ellipsis
1717
+ b_has_ellipsis = b and b[-1] is Ellipsis
1718
+ if a_has_ellipsis and b_has_ellipsis:
1719
+ # Delegate to comparing all the elements using the standard approach.
1720
+ pass
1721
+ elif a_has_ellipsis:
1722
+ return seq_match_ellipsis(a, b, match_generic=match_generic)
1723
+ elif b_has_ellipsis:
1724
+ return seq_match_ellipsis(b, a, match_generic=match_generic)
1725
+
1726
+ return len(a) == len(b) and all(types_equal(x, y, match_generic=match_generic) for x, y in zip(a, b))
1727
+ elif a_is_seq or b_is_seq:
1728
+ # A sequence can only match to another sequence.
1729
+ return False
1588
1730
 
1589
1731
  # convert to canonical types
1590
1732
  if a == float:
@@ -1621,7 +1763,7 @@ def types_equal(a, b, match_generic=False):
1621
1763
  return scalars_equal(a, b, match_generic)
1622
1764
 
1623
1765
 
1624
- def strides_from_shape(shape: Tuple, dtype):
1766
+ def strides_from_shape(shape: tuple, dtype):
1625
1767
  ndims = len(shape)
1626
1768
  strides = [None] * ndims
1627
1769
 
@@ -1635,7 +1777,7 @@ def strides_from_shape(shape: Tuple, dtype):
1635
1777
  return tuple(strides)
1636
1778
 
1637
1779
 
1638
- def check_array_shape(shape: Tuple):
1780
+ def check_array_shape(shape: tuple):
1639
1781
  """Checks that the size in each dimension is positive and less than 2^31."""
1640
1782
 
1641
1783
  for dim_index, dim_size in enumerate(shape):
@@ -1711,8 +1853,8 @@ class array(Array[DType]):
1711
1853
  ndim (int): The number of array dimensions.
1712
1854
  size (int): The number of items in the array.
1713
1855
  capacity (int): The amount of memory in bytes allocated for this array.
1714
- shape (Tuple[int]): Dimensions of the array.
1715
- strides (Tuple[int]): Number of bytes in each dimension between successive elements of the array.
1856
+ shape (tuple[int]): Dimensions of the array.
1857
+ strides (tuple[int]): Number of bytes in each dimension between successive elements of the array.
1716
1858
  ptr (int): Pointer to underlying memory allocation backing the array.
1717
1859
  device (Device): The device where the array's memory allocation resides.
1718
1860
  pinned (bool): Indicates whether the array was allocated in pinned host memory.
@@ -1726,26 +1868,24 @@ class array(Array[DType]):
1726
1868
  _vars = None
1727
1869
 
1728
1870
  def __new__(cls, *args, **kwargs):
1729
- instance = super(array, cls).__new__(cls)
1871
+ instance = super().__new__(cls)
1730
1872
  instance.deleter = None
1731
1873
  return instance
1732
1874
 
1733
1875
  def __init__(
1734
1876
  self,
1735
- data: Union[List, Tuple, npt.NDArray, None] = None,
1877
+ data: list | tuple | npt.NDArray | None = None,
1736
1878
  dtype: Any = Any,
1737
- shape: Union[int, Tuple[int, ...], List[int], None] = None,
1738
- strides: Optional[Tuple[int, ...]] = None,
1739
- length: Optional[int] = None,
1740
- ptr: Optional[int] = None,
1741
- capacity: Optional[int] = None,
1879
+ shape: int | tuple[int, ...] | list[int] | None = None,
1880
+ strides: tuple[int, ...] | None = None,
1881
+ ptr: int | None = None,
1882
+ capacity: int | None = None,
1742
1883
  device=None,
1743
1884
  pinned: builtins.bool = False,
1744
1885
  copy: builtins.bool = True,
1745
- owner: builtins.bool = False, # deprecated - pass deleter instead
1746
- deleter: Optional[Callable[[int, int], None]] = None,
1747
- ndim: Optional[int] = None,
1748
- grad: Optional[array] = None,
1886
+ deleter: Callable[[int, int], None] | None = None,
1887
+ ndim: int | None = None,
1888
+ grad: array | None = None,
1749
1889
  requires_grad: builtins.bool = False,
1750
1890
  ):
1751
1891
  """Constructs a new Warp array object
@@ -1759,7 +1899,7 @@ class array(Array[DType]):
1759
1899
  allocation should reside on the same device given by the device argument, and the user should set the length
1760
1900
  and dtype parameter appropriately.
1761
1901
 
1762
- If neither ``data`` nor ``ptr`` are specified, the ``shape`` or ``length`` arguments are checked next.
1902
+ If neither ``data`` nor ``ptr`` are specified, the ``shape`` argument is checked next.
1763
1903
  This construction path can be used to create new uninitialized arrays, but users are encouraged to call
1764
1904
  ``wp.empty()``, ``wp.zeros()``, or ``wp.full()`` instead to create new arrays.
1765
1905
 
@@ -1772,14 +1912,11 @@ class array(Array[DType]):
1772
1912
  dtype: One of the available `data types <#data-types>`_, such as :class:`warp.float32`, :class:`warp.mat33`, or a custom `struct <#structs>`_. If dtype is ``Any`` and data is an ndarray, then it will be inferred from the array data type
1773
1913
  shape: Dimensions of the array
1774
1914
  strides: Number of bytes in each dimension between successive elements of the array
1775
- length: Number of elements of the data type (deprecated, users should use ``shape`` argument)
1776
1915
  ptr: Address of an external memory address to alias (``data`` should be ``None``)
1777
1916
  capacity: Maximum size in bytes of the ``ptr`` allocation (``data`` should be ``None``)
1778
1917
  device (Devicelike): Device the array lives on
1779
1918
  copy: Whether the incoming ``data`` will be copied or aliased. Aliasing requires that
1780
1919
  the incoming ``data`` already lives on the ``device`` specified and the data types match.
1781
- owner: Whether the array will try to deallocate the underlying memory when it is deleted
1782
- (deprecated, pass ``deleter`` if you wish to transfer ownership to Warp)
1783
1920
  deleter: Function to be called when the array is deleted, taking two arguments: pointer and size
1784
1921
  requires_grad: Whether or not gradients will be tracked for this array, see :class:`warp.Tape` for details
1785
1922
  grad: The array in which to accumulate gradients in the backward pass. If ``None`` and ``requires_grad`` is ``True``,
@@ -1787,7 +1924,7 @@ class array(Array[DType]):
1787
1924
  pinned: Whether to allocate pinned host memory, which allows asynchronous host–device transfers
1788
1925
  (only applicable with ``device="cpu"``)
1789
1926
 
1790
- """
1927
+ """ # noqa: RUF002
1791
1928
 
1792
1929
  self.ctype = None
1793
1930
 
@@ -1821,23 +1958,6 @@ class array(Array[DType]):
1821
1958
  raise RuntimeError(
1822
1959
  f"Failed to create array with shape {shape}, the maximum number of dimensions is {ARRAY_MAX_DIMS}"
1823
1960
  )
1824
- elif length is not None:
1825
- # backward compatibility
1826
- warp.utils.warn(
1827
- "The 'length' keyword is deprecated and will be removed in a future version. Use 'shape' instead.",
1828
- category=DeprecationWarning,
1829
- stacklevel=2,
1830
- )
1831
- shape = (length,)
1832
-
1833
- if owner:
1834
- warp.utils.warn(
1835
- "The 'owner' keyword in the array initializer is\n"
1836
- "deprecated and will be removed in a future version. It currently has no effect.\n"
1837
- "Pass a function to the 'deleter' keyword instead.",
1838
- category=DeprecationWarning,
1839
- stacklevel=2,
1840
- )
1841
1961
 
1842
1962
  # determine the construction path from the given arguments
1843
1963
  if data is not None:
@@ -1891,65 +2011,131 @@ class array(Array[DType]):
1891
2011
  desc = data.__cuda_array_interface__
1892
2012
  data_shape = desc.get("shape")
1893
2013
  data_strides = desc.get("strides")
1894
- data_dtype = np.dtype(desc.get("typestr"))
2014
+ data_dtype_np = np.dtype(desc.get("typestr"))
2015
+ data_dtype = dtype_from_numpy(data_dtype_np)
1895
2016
  data_ptr = desc.get("data")[0]
1896
2017
 
1897
2018
  if dtype == Any:
1898
- dtype = np_dtype_to_warp_type[data_dtype]
2019
+ dtype = data_dtype
1899
2020
  else:
1900
2021
  # Warn if the data type is compatible with the requested dtype
1901
- if not np_dtype_is_compatible(data_dtype, dtype):
2022
+ if not np_dtype_is_compatible(data_dtype_np, dtype):
1902
2023
  warp.utils.warn(
1903
- f"The input data type {data_dtype} does not appear to be "
2024
+ f"The input data type {data_dtype_np} does not appear to be "
1904
2025
  f"compatible with the requested dtype {dtype}. If "
1905
2026
  "data-type sizes do not match, then this may lead to memory-access violations."
1906
2027
  )
1907
2028
 
1908
2029
  if data_strides is None:
1909
- data_strides = strides_from_shape(data_shape, dtype)
2030
+ data_strides = strides_from_shape(data_shape, data_dtype)
1910
2031
 
1911
2032
  data_ndim = len(data_shape)
1912
2033
 
1913
- # determine whether the input needs reshaping
1914
- target_npshape = None
1915
- if shape is not None:
1916
- target_npshape = (*shape, *dtype_shape)
1917
- elif dtype_ndim > 0:
1918
- # prune inner dimensions of length 1
1919
- while data_ndim > 1 and data_shape[-1] == 1:
1920
- data_shape = data_shape[:-1]
1921
- # if the inner dims don't match exactly, check if the innermost dim is a multiple of type length
1922
- if data_ndim < dtype_ndim or data_shape[-dtype_ndim:] != dtype_shape:
1923
- if data_shape[-1] == dtype._length_:
1924
- target_npshape = (*data_shape[:-1], *dtype_shape)
1925
- elif data_shape[-1] % dtype._length_ == 0:
1926
- target_npshape = (*data_shape[:-1], data_shape[-1] // dtype._length_, *dtype_shape)
2034
+ # determine shape and strides
2035
+ if shape is None:
2036
+ if dtype_ndim == 0:
2037
+ # scalars
2038
+ shape = data_shape
2039
+ strides = data_strides
2040
+ else:
2041
+ # vectors/matrices
2042
+ if data_ndim >= dtype_ndim and data_shape[-dtype_ndim:] == dtype_shape:
2043
+ # the inner shape matches exactly, check inner strides
2044
+ if data_strides[-dtype_ndim:] != strides_from_shape(dtype._shape_, dtype._wp_scalar_type_):
2045
+ raise RuntimeError(
2046
+ f"The inner strides of the input array {data_strides} are not compatible with the requested data type {warp.context.type_str(dtype)}"
2047
+ )
2048
+ shape = data_shape[:-dtype_ndim] or (1,)
2049
+ strides = data_strides[:-dtype_ndim] or (type_size_in_bytes(dtype),)
1927
2050
  else:
1928
- if dtype_ndim == 1:
2051
+ # ensure inner strides are contiguous
2052
+ if data_strides[-1] != type_size_in_bytes(data_dtype):
1929
2053
  raise RuntimeError(
1930
- f"The inner dimensions of the input data are not compatible with the requested vector type {warp.context.type_str(dtype)}: expected an inner dimension that is a multiple of {dtype._length_}"
2054
+ f"The inner strides of the input array {data_strides} are not compatible with the requested data type {warp.context.type_str(dtype)}"
1931
2055
  )
2056
+ # check if the innermost dim is a multiple of type length
2057
+ if data_shape[-1] == dtype._length_:
2058
+ shape = data_shape[:-1] or (1,)
2059
+ strides = data_strides[:-1] or (type_size_in_bytes(dtype),)
2060
+ elif data_shape[-1] % dtype._length_ == 0:
2061
+ shape = (*data_shape[:-1], data_shape[-1] // dtype._length_)
2062
+ strides = (*data_strides[:-1], data_strides[-1] * dtype._length_)
1932
2063
  else:
1933
2064
  raise RuntimeError(
1934
- f"The inner dimensions of the input data are not compatible with the requested matrix type {warp.context.type_str(dtype)}: expected inner dimensions {dtype._shape_} or a multiple of {dtype._length_}"
2065
+ f"The shape of the input array {data_shape} is not compatible with the requested data type {warp.context.type_str(dtype)}"
1935
2066
  )
1936
-
1937
- if target_npshape is None:
1938
- target_npshape = data_shape if shape is None else shape
1939
-
1940
- # determine final shape and strides
1941
- if dtype_ndim > 0:
1942
- # make sure the inner dims are contiguous for vector/matrix types
1943
- scalar_size = type_size_in_bytes(dtype._wp_scalar_type_)
1944
- inner_contiguous = data_strides[-1] == scalar_size
1945
- if inner_contiguous and dtype_ndim > 1:
1946
- inner_contiguous = data_strides[-2] == scalar_size * dtype_shape[-1]
1947
-
1948
- shape = target_npshape[:-dtype_ndim] or (1,)
1949
- strides = data_strides if shape == data_shape else strides_from_shape(shape, dtype)
1950
2067
  else:
1951
- shape = target_npshape or (1,)
1952
- strides = data_strides if shape == data_shape else strides_from_shape(shape, dtype)
2068
+ # a shape was given, reshape if needed
2069
+ if dtype_ndim == 0:
2070
+ # scalars
2071
+ if shape == data_shape:
2072
+ strides = data_strides
2073
+ else:
2074
+ # check if given shape is compatible
2075
+ if math.prod(shape) != math.prod(data_shape):
2076
+ raise RuntimeError(
2077
+ f"The shape of the input array {data_shape} is not compatible with the requested shape {shape}"
2078
+ )
2079
+ # check if data is contiguous
2080
+ if data_strides != strides_from_shape(data_shape, data_dtype):
2081
+ raise RuntimeError(
2082
+ f"The requested shape {shape} is not possible because the input array is not contiguous"
2083
+ )
2084
+ strides = strides_from_shape(shape, dtype)
2085
+ else:
2086
+ # vectors/matrices
2087
+ if data_ndim >= dtype_ndim and data_shape[-dtype_ndim:] == dtype_shape:
2088
+ # the inner shape matches exactly, check outer shape
2089
+ if shape == data_shape[:-dtype_ndim]:
2090
+ strides = data_strides[:-dtype_ndim]
2091
+ else:
2092
+ # check if given shape is compatible
2093
+ if math.prod(shape) != math.prod(data_shape[:-dtype_ndim]):
2094
+ raise RuntimeError(
2095
+ f"The shape of the input array {data_shape} is not compatible with the requested shape {shape} and data type {warp.context.type_str(dtype)}"
2096
+ )
2097
+ # check if data is contiguous
2098
+ if data_strides != strides_from_shape(data_shape, data_dtype):
2099
+ raise RuntimeError(
2100
+ f"The requested shape {shape} is not possible because the input array is not contiguous"
2101
+ )
2102
+ strides = strides_from_shape(shape, dtype)
2103
+ else:
2104
+ # check if the innermost dim is a multiple of type length
2105
+ if data_shape[-1] == dtype._length_:
2106
+ if shape == data_shape[:-1]:
2107
+ strides = data_strides[:-1]
2108
+ else:
2109
+ # check if given shape is compatible
2110
+ if math.prod(shape) != math.prod(data_shape[:-1]):
2111
+ raise RuntimeError(
2112
+ f"The shape of the input array {data_shape} is not compatible with the requested shape {shape} and data type {warp.context.type_str(dtype)}"
2113
+ )
2114
+ # check if data is contiguous
2115
+ if data_strides != strides_from_shape(data_shape, data_dtype):
2116
+ raise RuntimeError(
2117
+ f"The requested shape {shape} is not possible because the input array is not contiguous"
2118
+ )
2119
+ strides = strides_from_shape(shape, dtype)
2120
+ elif data_shape[-1] % dtype._length_ == 0:
2121
+ if shape == (*data_shape[:-1], data_shape[-1] // dtype._length_):
2122
+ strides = (*data_strides[:-1], data_strides[-1] * dtype._length_)
2123
+ else:
2124
+ # check if given shape is compatible
2125
+ if math.prod(shape) != math.prod((*data_shape[:-1], data_shape[-1] // dtype._length_)):
2126
+ raise RuntimeError(
2127
+ f"The shape of the input array {data_shape} is not compatible with the requested shape {shape} and data type {warp.context.type_str(dtype)}"
2128
+ )
2129
+ # check if data is contiguous
2130
+ if data_strides != strides_from_shape(data_shape, data_dtype):
2131
+ raise RuntimeError(
2132
+ f"The requested shape {shape} is not possible because the input array is not contiguous"
2133
+ )
2134
+ strides = strides_from_shape(shape, dtype)
2135
+ else:
2136
+ raise RuntimeError(
2137
+ f"The shape of the input array {data_shape} is not compatible with the requested data type {warp.context.type_str(dtype)} and requested shape {shape}"
2138
+ )
1953
2139
 
1954
2140
  self._init_from_ptr(data_ptr, dtype, shape, strides, None, device, False, None)
1955
2141
 
@@ -2466,9 +2652,7 @@ class array(Array[DType]):
2466
2652
 
2467
2653
  if self.ndim != 2 or other.ndim != 2:
2468
2654
  raise RuntimeError(
2469
- "A has dim = {}, B has dim = {}. If multiplying with @, A and B must have dim = 2.".format(
2470
- self.ndim, other.ndim
2471
- )
2655
+ f"A has dim = {self.ndim}, B has dim = {other.ndim}. If multiplying with @, A and B must have dim = 2."
2472
2656
  )
2473
2657
 
2474
2658
  m = self.shape[0]
@@ -3011,7 +3195,7 @@ def _close_cuda_ipc_handle(ptr, size):
3011
3195
 
3012
3196
 
3013
3197
  def from_ipc_handle(
3014
- handle: bytes, dtype, shape: Tuple[int, ...], strides: Optional[Tuple[int, ...]] = None, device=None
3198
+ handle: bytes, dtype, shape: tuple[int, ...], strides: tuple[int, ...] | None = None, device=None
3015
3199
  ) -> array:
3016
3200
  """Create an array from an IPC handle.
3017
3201
 
@@ -3157,10 +3341,10 @@ class indexedarray(noncontiguous_array_base):
3157
3341
 
3158
3342
  def __init__(
3159
3343
  self,
3160
- data: Optional[array] = None,
3161
- indices: Union[array, List[array], None] = None,
3344
+ data: array | None = None,
3345
+ indices: array | list[array] | None = None,
3162
3346
  dtype=None,
3163
- ndim: Optional[int] = None,
3347
+ ndim: int | None = None,
3164
3348
  ):
3165
3349
  super().__init__(ARRAY_TYPE_INDEXED)
3166
3350
 
@@ -3294,14 +3478,32 @@ def array_type_id(a):
3294
3478
  raise ValueError("Invalid array type")
3295
3479
 
3296
3480
 
3297
- # tile object
3298
- class Tile:
3481
+ class tile(Tile[DType, Shape]):
3482
+ """A Warp tile object.
3483
+
3484
+ Attributes:
3485
+ dtype (DType): The data type of the tile
3486
+ shape (Shape): Dimensions of the tile
3487
+ storage (str): 'register' or 'shared' memory storage
3488
+ layout (str): 'rowmajor' or 'colmajor' layout
3489
+ strides (tuple[int]): Number of tile elements between successive tile entries in each dimension
3490
+ size (int): Total number of tile elements
3491
+ owner (bool): Whether this tile aliases or owns its data
3492
+ """
3493
+
3299
3494
  alignment = 16
3300
3495
 
3301
- def __init__(self, dtype, shape, op=None, storage="register", layout="rowmajor", strides=None, owner=True):
3496
+ def __init__(
3497
+ self,
3498
+ dtype: Any,
3499
+ shape: tuple[int, ...] | list[int],
3500
+ storage: str = "register",
3501
+ layout: str = "rowmajor",
3502
+ strides: tuple[int, ...] | None = None,
3503
+ owner: builtins.bool = True,
3504
+ ):
3302
3505
  self.dtype = type_to_warp(dtype)
3303
3506
  self.shape = shape
3304
- self.op = op
3305
3507
  self.storage = storage
3306
3508
  self.layout = layout
3307
3509
  self.strides = strides
@@ -3349,7 +3551,7 @@ class Tile:
3349
3551
  elif self.storage == "shared":
3350
3552
  if self.owner:
3351
3553
  # allocate new shared memory tile
3352
- return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,{'true' if requires_grad else 'false'}>()"
3554
+ return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,wp::tile_stride_t<{','.join(map(str, self.strides))}>,{'true' if requires_grad else 'false'}>()"
3353
3555
  else:
3354
3556
  # tile will be initialized by another call, e.g.: tile_transpose()
3355
3557
  return "nullptr"
@@ -3361,76 +3563,15 @@ class Tile:
3361
3563
 
3362
3564
  @staticmethod
3363
3565
  def round_up(bytes):
3364
- return ((bytes + Tile.alignment - 1) // Tile.alignment) * Tile.alignment
3566
+ return ((bytes + tile.alignment - 1) // tile.alignment) * tile.alignment
3365
3567
 
3366
3568
  # align tile size to natural boundary, default 16-bytes
3367
3569
  def align(self, bytes):
3368
- return Tile.round_up(bytes)
3369
-
3370
-
3371
- class TileZeros(Tile):
3372
- def __init__(self, dtype, shape, storage="register"):
3373
- Tile.__init__(self, dtype, shape, op="zeros", storage=storage)
3374
-
3375
-
3376
- class TileOnes(Tile):
3377
- def __init__(self, dtype, shape, storage="register"):
3378
- Tile.__init__(self, dtype, shape, op="ones", storage=storage)
3379
-
3380
-
3381
- class TileRange(Tile):
3382
- def __init__(self, dtype, start, stop, step, storage="register"):
3383
- self.start = start
3384
- self.stop = stop
3385
- self.step = step
3386
-
3387
- n = int((stop - start) / step)
3388
-
3389
- Tile.__init__(self, dtype, shape=(n,), op="arange", storage=storage)
3390
-
3391
-
3392
- class TileConstant(Tile):
3393
- def __init__(self, dtype, shape):
3394
- Tile.__init__(self, dtype, shape, op="constant", storage="register")
3395
-
3396
-
3397
- class TileLoad(Tile):
3398
- def __init__(self, array, shape, storage="register"):
3399
- Tile.__init__(self, array.dtype, shape, op="load", storage=storage)
3400
-
3401
-
3402
- class TileUnaryMap(Tile):
3403
- def __init__(self, t, dtype=None, storage="register"):
3404
- Tile.__init__(self, dtype, t.shape, op="unary_map", storage=storage)
3405
-
3406
- # if no output dtype specified then assume it's the same as the first arg
3407
- if self.dtype is None:
3408
- self.dtype = t.dtype
3409
-
3410
- self.t = t
3411
-
3412
-
3413
- class TileBinaryMap(Tile):
3414
- def __init__(self, a, b, dtype=None, storage="register"):
3415
- Tile.__init__(self, dtype, a.shape, op="binary_map", storage=storage)
3416
-
3417
- # if no output dtype specified then assume it's the same as the first arg
3418
- if self.dtype is None:
3419
- self.dtype = a.dtype
3420
-
3421
- self.a = a
3422
- self.b = b
3423
-
3424
-
3425
- class TileShared(Tile):
3426
- def __init__(self, t):
3427
- Tile.__init__(self, t.dtype, t.shape, "shared", storage="shared")
3428
-
3429
- self.t = t
3570
+ return tile.round_up(bytes)
3430
3571
 
3431
3572
 
3432
3573
  def is_tile(t):
3433
- return isinstance(t, Tile)
3574
+ return isinstance(t, tile)
3434
3575
 
3435
3576
 
3436
3577
  bvh_constructor_values = {"sah": 0, "median": 1, "lbvh": 2}
@@ -3438,11 +3579,11 @@ bvh_constructor_values = {"sah": 0, "median": 1, "lbvh": 2}
3438
3579
 
3439
3580
  class Bvh:
3440
3581
  def __new__(cls, *args, **kwargs):
3441
- instance = super(Bvh, cls).__new__(cls)
3582
+ instance = super().__new__(cls)
3442
3583
  instance.id = None
3443
3584
  return instance
3444
3585
 
3445
- def __init__(self, lowers: array, uppers: array, constructor: Optional[str] = None):
3586
+ def __init__(self, lowers: array, uppers: array, constructor: str | None = None):
3446
3587
  """Class representing a bounding volume hierarchy.
3447
3588
 
3448
3589
  Depending on which device the input bounds live, it can be either a CPU tree or a GPU tree.
@@ -3523,14 +3664,14 @@ class Bvh:
3523
3664
  constructor = "sah"
3524
3665
 
3525
3666
  self.id = self.runtime.core.bvh_create_host(
3526
- get_data(lowers), get_data(uppers), int(len(lowers)), bvh_constructor_values[constructor]
3667
+ get_data(lowers), get_data(uppers), len(lowers), bvh_constructor_values[constructor]
3527
3668
  )
3528
3669
  else:
3529
3670
  self.id = self.runtime.core.bvh_create_device(
3530
3671
  self.device.context,
3531
3672
  get_data(lowers),
3532
3673
  get_data(uppers),
3533
- int(len(lowers)),
3674
+ len(lowers),
3534
3675
  bvh_constructor_values[constructor],
3535
3676
  )
3536
3677
 
@@ -3561,14 +3702,14 @@ class Bvh:
3561
3702
  class Mesh:
3562
3703
  from warp.codegen import Var
3563
3704
 
3564
- vars = {
3705
+ vars: ClassVar[dict[str, Var]] = {
3565
3706
  "points": Var("points", array(dtype=vec3)),
3566
3707
  "velocities": Var("velocities", array(dtype=vec3)),
3567
3708
  "indices": Var("indices", array(dtype=int32)),
3568
3709
  }
3569
3710
 
3570
3711
  def __new__(cls, *args, **kwargs):
3571
- instance = super(Mesh, cls).__new__(cls)
3712
+ instance = super().__new__(cls)
3572
3713
  instance.id = None
3573
3714
  return instance
3574
3715
 
@@ -3576,9 +3717,9 @@ class Mesh:
3576
3717
  self,
3577
3718
  points: array,
3578
3719
  indices: array,
3579
- velocities: Optional[array] = None,
3580
- support_winding_number: bool = False,
3581
- bvh_constructor: Optional[str] = None,
3720
+ velocities: array | None = None,
3721
+ support_winding_number: builtins.bool = False,
3722
+ bvh_constructor: str | None = None,
3582
3723
  ):
3583
3724
  """Class representing a triangle mesh.
3584
3725
 
@@ -3640,7 +3781,7 @@ class Mesh:
3640
3781
  points.__ctype__(),
3641
3782
  velocities.__ctype__() if velocities else array().__ctype__(),
3642
3783
  indices.__ctype__(),
3643
- int(len(points)),
3784
+ len(points),
3644
3785
  int(indices.size / 3),
3645
3786
  int(support_winding_number),
3646
3787
  bvh_constructor_values[bvh_constructor],
@@ -3651,7 +3792,7 @@ class Mesh:
3651
3792
  points.__ctype__(),
3652
3793
  velocities.__ctype__() if velocities else array().__ctype__(),
3653
3794
  indices.__ctype__(),
3654
- int(len(points)),
3795
+ len(points),
3655
3796
  int(indices.size / 3),
3656
3797
  int(support_winding_number),
3657
3798
  bvh_constructor_values[bvh_constructor],
@@ -3750,7 +3891,7 @@ class Volume:
3750
3891
  LINEAR = constant(1)
3751
3892
 
3752
3893
  def __new__(cls, *args, **kwargs):
3753
- instance = super(Volume, cls).__new__(cls)
3894
+ instance = super().__new__(cls)
3754
3895
  instance.id = None
3755
3896
  return instance
3756
3897
 
@@ -3799,7 +3940,7 @@ class Volume:
3799
3940
  buf = ctypes.c_void_p(0)
3800
3941
  size = ctypes.c_uint64(0)
3801
3942
  self.runtime.core.volume_get_buffer_info(self.id, ctypes.byref(buf), ctypes.byref(size))
3802
- return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)
3943
+ return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device)
3803
3944
 
3804
3945
  def get_tile_count(self) -> int:
3805
3946
  """Return the number of tiles (NanoVDB leaf nodes) of the volume."""
@@ -3811,7 +3952,7 @@ class Volume:
3811
3952
  self.runtime.core.volume_get_tile_and_voxel_count(self.id, ctypes.byref(tile_count), ctypes.byref(voxel_count))
3812
3953
  return tile_count.value
3813
3954
 
3814
- def get_tiles(self, out: Optional[array] = None) -> array:
3955
+ def get_tiles(self, out: array | None = None) -> array:
3815
3956
  """Return the integer coordinates of all allocated tiles for this volume.
3816
3957
 
3817
3958
  Args:
@@ -3851,7 +3992,7 @@ class Volume:
3851
3992
  self.runtime.core.volume_get_tile_and_voxel_count(self.id, ctypes.byref(tile_count), ctypes.byref(voxel_count))
3852
3993
  return voxel_count.value
3853
3994
 
3854
- def get_voxels(self, out: Optional[array] = None) -> array:
3995
+ def get_voxels(self, out: array | None = None) -> array:
3855
3996
  """Return the integer coordinates of all allocated voxels for this volume.
3856
3997
 
3857
3998
  Args:
@@ -3880,7 +4021,7 @@ class Volume:
3880
4021
 
3881
4022
  return out
3882
4023
 
3883
- def get_voxel_size(self) -> Tuple[float, float, float]:
4024
+ def get_voxel_size(self) -> tuple[float, float, float]:
3884
4025
  """Return the voxel size, i.e, world coordinates of voxel's diagonal vector"""
3885
4026
 
3886
4027
  if self.id == 0:
@@ -3943,7 +4084,7 @@ class Volume:
3943
4084
  mat33f.from_buffer_copy(transform_buffer),
3944
4085
  )
3945
4086
 
3946
- _nvdb_type_to_dtype = {
4087
+ _nvdb_type_to_dtype: ClassVar[dict[str, type]] = {
3947
4088
  "float": float32,
3948
4089
  "double": float64,
3949
4090
  "int16": int16,
@@ -4061,7 +4202,7 @@ class Volume:
4061
4202
  if type_size_in_bytes(dtype) != value_size:
4062
4203
  raise RuntimeError(f"Cannot cast feature data of size {value_size} to array dtype {type_repr(dtype)}")
4063
4204
 
4064
- return array(ptr=info.ptr, dtype=dtype, shape=value_count, device=self.device, owner=False)
4205
+ return array(ptr=info.ptr, dtype=dtype, shape=value_count, device=self.device)
4065
4206
 
4066
4207
  @classmethod
4067
4208
  def load_from_nvdb(cls, file_or_buffer, device=None) -> Volume:
@@ -4141,15 +4282,15 @@ class Volume:
4141
4282
  codec_dict = {"none": 0, "zip": 1, "blosc": 2}
4142
4283
 
4143
4284
  class FileHeader(ctypes.Structure):
4144
- _fields_ = [
4285
+ _fields_ = (
4145
4286
  ("magic", ctypes.c_uint64),
4146
4287
  ("version", ctypes.c_uint32),
4147
4288
  ("gridCount", ctypes.c_uint16),
4148
4289
  ("codec", ctypes.c_uint16),
4149
- ]
4290
+ )
4150
4291
 
4151
4292
  class FileMetaData(ctypes.Structure):
4152
- _fields_ = [
4293
+ _fields_ = (
4153
4294
  ("gridSize", ctypes.c_uint64),
4154
4295
  ("fileSize", ctypes.c_uint64),
4155
4296
  ("nameKey", ctypes.c_uint64),
@@ -4165,10 +4306,10 @@ class Volume:
4165
4306
  ("codec", ctypes.c_uint16),
4166
4307
  ("padding", ctypes.c_uint16),
4167
4308
  ("version", ctypes.c_uint32),
4168
- ]
4309
+ )
4169
4310
 
4170
4311
  class GridData(ctypes.Structure):
4171
- _fields_ = [
4312
+ _fields_ = (
4172
4313
  ("magic", ctypes.c_uint64),
4173
4314
  ("checksum", ctypes.c_uint64),
4174
4315
  ("version", ctypes.c_uint32),
@@ -4187,7 +4328,7 @@ class Volume:
4187
4328
  ("data0", ctypes.c_uint32),
4188
4329
  ("data1", ctypes.c_uint64),
4189
4330
  ("data2", ctypes.c_uint64),
4190
- ]
4331
+ )
4191
4332
 
4192
4333
  NVDB_MAGIC = 0x304244566F6E614E
4193
4334
  NVDB_VERSION = 32 << 21 | 3 << 10 | 3
@@ -4293,7 +4434,7 @@ class Volume:
4293
4434
  "A warp Volume has already been created for this grid, aliasing it more than once is not possible."
4294
4435
  )
4295
4436
 
4296
- data_array = array(ptr=grid_ptr, dtype=uint8, shape=buffer_size, owner=False, device=device)
4437
+ data_array = array(ptr=grid_ptr, dtype=uint8, shape=buffer_size, device=device)
4297
4438
 
4298
4439
  return cls(data_array, copy=False)
4299
4440
 
@@ -4410,8 +4551,8 @@ class Volume:
4410
4551
  @classmethod
4411
4552
  def allocate(
4412
4553
  cls,
4413
- min: List[int],
4414
- max: List[int],
4554
+ min: list[int],
4555
+ max: list[int],
4415
4556
  voxel_size: float,
4416
4557
  bg_value=0.0,
4417
4558
  translation=(0.0, 0.0, 0.0),
@@ -4459,7 +4600,7 @@ class Volume:
4459
4600
 
4460
4601
  @staticmethod
4461
4602
  def _fill_transform_buffers(
4462
- voxel_size: Union[float, List[float]],
4603
+ voxel_size: float | list[float],
4463
4604
  translation,
4464
4605
  transform,
4465
4606
  ):
@@ -4483,18 +4624,13 @@ class Volume:
4483
4624
 
4484
4625
  # nanovdb types for which we instantiate the grid builder
4485
4626
  # Should be in sync with WP_VOLUME_BUILDER_INSTANTIATE_TYPES in volume_builder.h
4486
- _supported_allocation_types = [
4487
- "int32",
4488
- "float",
4489
- "Vec3f",
4490
- "Vec4f",
4491
- ]
4627
+ _supported_allocation_types = ("int32", "float", "Vec3f", "Vec4f")
4492
4628
 
4493
4629
  @classmethod
4494
4630
  def allocate_by_tiles(
4495
4631
  cls,
4496
4632
  tile_points: array,
4497
- voxel_size: Union[float, List[float]] = None,
4633
+ voxel_size: float | list[float] | None = None,
4498
4634
  bg_value=0.0,
4499
4635
  translation=(0.0, 0.0, 0.0),
4500
4636
  device=None,
@@ -4602,7 +4738,7 @@ class Volume:
4602
4738
  def allocate_by_voxels(
4603
4739
  cls,
4604
4740
  voxel_points: array,
4605
- voxel_size: Union[float, List[float]] = None,
4741
+ voxel_size: float | list[float] | None = None,
4606
4742
  translation=(0.0, 0.0, 0.0),
4607
4743
  device=None,
4608
4744
  transform=None,
@@ -4658,20 +4794,20 @@ class Volume:
4658
4794
  return volume
4659
4795
 
4660
4796
 
4661
- def _is_contiguous_vec_like_array(array, vec_length: int, scalar_types: Tuple[type]) -> bool:
4797
+ def _is_contiguous_vec_like_array(array, vec_length: int, scalar_types: tuple[type]) -> builtins.bool:
4662
4798
  if not (is_array(array) and array.is_contiguous):
4663
4799
  return False
4664
4800
  if type_scalar_type(array.dtype) not in scalar_types:
4665
4801
  return False
4666
- return (array.ndim == 1 and type_length(array.dtype) == vec_length) or (
4667
- array.ndim == 2 and array.shape[1] == vec_length and type_length(array.dtype) == 1
4802
+ return (array.ndim == 1 and type_size(array.dtype) == vec_length) or (
4803
+ array.ndim == 2 and array.shape[1] == vec_length and type_size(array.dtype) == 1
4668
4804
  )
4669
4805
 
4670
4806
 
4671
4807
  # definition just for kernel type (cannot be a parameter), see mesh.h
4672
- # NOTE: its layout must match the corresponding struct defined in C.
4808
+ # NOTE: its layout must match the mesh_query_point_t struct defined in C.
4673
4809
  # NOTE: it needs to be defined after `indexedarray` to workaround a circular import issue.
4674
- class mesh_query_point_t:
4810
+ class MeshQueryPoint:
4675
4811
  """Output for the mesh query point functions.
4676
4812
 
4677
4813
  Attributes:
@@ -4691,7 +4827,9 @@ class mesh_query_point_t:
4691
4827
 
4692
4828
  from warp.codegen import Var
4693
4829
 
4694
- vars = {
4830
+ _wp_native_name_ = "mesh_query_point_t"
4831
+
4832
+ vars: ClassVar[dict[str, Var]] = {
4695
4833
  "result": Var("result", bool),
4696
4834
  "sign": Var("sign", float32),
4697
4835
  "face": Var("face", int32),
@@ -4700,12 +4838,9 @@ class mesh_query_point_t:
4700
4838
  }
4701
4839
 
4702
4840
 
4703
- MeshQueryPoint = mesh_query_point_t
4704
-
4705
-
4706
4841
  # definition just for kernel type (cannot be a parameter), see mesh.h
4707
- # NOTE: its layout must match the corresponding struct defined in C.
4708
- class mesh_query_ray_t:
4842
+ # NOTE: its layout must match the mesh_query_ray_t struct defined in C.
4843
+ class MeshQueryRay:
4709
4844
  """Output for the mesh query ray functions.
4710
4845
 
4711
4846
  Attributes:
@@ -4723,7 +4858,9 @@ class mesh_query_ray_t:
4723
4858
 
4724
4859
  from warp.codegen import Var
4725
4860
 
4726
- vars = {
4861
+ _wp_native_name_ = "mesh_query_ray_t"
4862
+
4863
+ vars: ClassVar[dict[str, Var]] = {
4727
4864
  "result": Var("result", bool),
4728
4865
  "sign": Var("sign", float32),
4729
4866
  "face": Var("face", int32),
@@ -4734,9 +4871,6 @@ class mesh_query_ray_t:
4734
4871
  }
4735
4872
 
4736
4873
 
4737
- MeshQueryRay = mesh_query_ray_t
4738
-
4739
-
4740
4874
  def matmul(
4741
4875
  a: array2d,
4742
4876
  b: array2d,
@@ -4863,7 +4997,7 @@ def adj_batched_matmul(
4863
4997
 
4864
4998
  class HashGrid:
4865
4999
  def __new__(cls, *args, **kwargs):
4866
- instance = super(HashGrid, cls).__new__(cls)
5000
+ instance = super().__new__(cls)
4867
5001
  instance.id = None
4868
5002
  return instance
4869
5003
 
@@ -4938,7 +5072,7 @@ class HashGrid:
4938
5072
 
4939
5073
  class MarchingCubes:
4940
5074
  def __new__(cls, *args, **kwargs):
4941
- instance = super(MarchingCubes, cls).__new__(cls)
5075
+ instance = super().__new__(cls)
4942
5076
  instance.id = None
4943
5077
  return instance
4944
5078
 
@@ -5101,6 +5235,9 @@ def type_is_generic(t):
5101
5235
  if is_array(t):
5102
5236
  return type_is_generic(t.dtype)
5103
5237
 
5238
+ if get_origin(t) is tuple:
5239
+ return True
5240
+
5104
5241
  if hasattr(t, "_wp_scalar_type_"):
5105
5242
  # vector/matrix type, check if dtype is generic
5106
5243
  if type_is_generic(t._wp_scalar_type_):
@@ -5170,8 +5307,18 @@ def type_matches_template(arg_type, template_type):
5170
5307
  return True
5171
5308
 
5172
5309
 
5173
- def infer_argument_types(args, template_types, arg_names=None):
5174
- """Resolve argument types with the given list of template types."""
5310
+ def infer_argument_types(args: list[Any], template_types, arg_names: list[str] | None = None) -> list[type]:
5311
+ """Resolve argument types with the given list of template types.
5312
+
5313
+ Args:
5314
+ args: List of arguments to infer types for.
5315
+ template_types: List of template types to match against.
5316
+ arg_names: List of argument names to use for error messages.
5317
+
5318
+ Raises:
5319
+ RuntimeError: Number of arguments must match number of template types.
5320
+ TypeError: Unable to infer the type of an argument.
5321
+ """
5175
5322
 
5176
5323
  if len(args) != len(template_types):
5177
5324
  raise RuntimeError("Number of arguments must match number of template types.")
@@ -5195,10 +5342,6 @@ def infer_argument_types(args, template_types, arg_names=None):
5195
5342
  elif issubclass(arg_type, warp.codegen.StructInstance):
5196
5343
  # a struct
5197
5344
  arg_types.append(arg._cls)
5198
- # elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
5199
- # arg_types.append(arg_type)
5200
- # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.mesh_query_point_t, warp.mesh_query_ray_t, warp.bvh_query_t]:
5201
- # arg_types.append(arg_type)
5202
5345
  elif arg is None:
5203
5346
  # allow passing None for arrays
5204
5347
  t = template_types[i]
@@ -5233,11 +5376,11 @@ simple_type_codes = {
5233
5376
  shape_t: "sh",
5234
5377
  range_t: "rg",
5235
5378
  launch_bounds_t: "lb",
5236
- hash_grid_query_t: "hgq",
5237
- mesh_query_aabb_t: "mqa",
5238
- mesh_query_point_t: "mqp",
5239
- mesh_query_ray_t: "mqr",
5240
- bvh_query_t: "bvhq",
5379
+ HashGridQuery: "hgq",
5380
+ MeshQueryAABB: "mqa",
5381
+ MeshQueryPoint: "mqp",
5382
+ MeshQueryRay: "mqr",
5383
+ BvhQuery: "bvhq",
5241
5384
  }
5242
5385
 
5243
5386
 
@@ -5287,8 +5430,17 @@ def get_type_code(arg_type: type) -> str:
5287
5430
  return f"fa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
5288
5431
  elif isinstance(arg_type, indexedfabricarray):
5289
5432
  return f"ifa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
5433
+ elif get_origin(arg_type) is tuple:
5434
+ arg_types = get_args(arg_type)
5435
+ return f"tpl{len(arg_types)}{''.join(get_type_code(x) for x in arg_types)}"
5436
+ elif isinstance(arg_type, tuple_t):
5437
+ return f"tplt{len(arg_type.types)}{''.join(get_type_code(x) for x in arg_type.types)}"
5290
5438
  elif isinstance(arg_type, warp.codegen.Struct):
5291
5439
  return arg_type.native_name
5440
+ elif isinstance(arg_type, tile):
5441
+ shape_string = "".join(str(num) for num in arg_type.shape)
5442
+ storage = "s" if arg_type.storage == "shared" else "r"
5443
+ return f"t{storage}{shape_string}{get_type_code(arg_type.dtype)}"
5292
5444
  elif arg_type == Scalar:
5293
5445
  # generic scalar type
5294
5446
  return "s?"
@@ -5301,12 +5453,14 @@ def get_type_code(arg_type: type) -> str:
5301
5453
  elif isinstance(arg_type, Callable):
5302
5454
  # TODO: elaborate on Callable type?
5303
5455
  return "c"
5456
+ elif arg_type is Ellipsis:
5457
+ return "?"
5304
5458
  else:
5305
5459
  raise TypeError(f"Unrecognized type '{arg_type}'")
5306
5460
 
5307
5461
 
5308
- def get_signature(arg_types: List[type], func_name: Optional[str] = None, arg_names: Optional[List[str]] = None) -> str:
5309
- type_codes: List[str] = []
5462
+ def get_signature(arg_types: list[type], func_name: str | None = None, arg_names: list[str] | None = None) -> str:
5463
+ type_codes: list[str] = []
5310
5464
  for i, arg_type in enumerate(arg_types):
5311
5465
  try:
5312
5466
  type_codes.append(get_type_code(arg_type))