warp-lang 1.7.2rc1-py3-none-manylinux_2_34_aarch64.whl → 1.8.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/utils.py CHANGED
@@ -21,7 +21,8 @@ import os
 import sys
 import time
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Union
+from types import ModuleType
+from typing import Any, Callable

 import numpy as np

@@ -29,6 +30,7 @@ import warp as wp
 import warp.context
 import warp.types
 from warp.context import Devicelike
+from warp.types import Array, DType, type_repr, types_equal

 warnings_seen = set()

@@ -52,7 +54,7 @@ def warp_showwarning(message, category, filename, lineno, file=None, line=None):

         if line:
             line = line.strip()
-            s += "  %s\n" % line
+            s += f"  {line}\n"
     else:
         # simple warning
         s = f"Warp {category.__name__}: {message}\n"
@@ -96,14 +98,31 @@ def quat_between_vectors(a: wp.vec3, b: wp.vec3) -> wp.quat:


 def array_scan(in_array, out_array, inclusive=True):
+    """Perform a scan (prefix sum) operation on an array.
+
+    This function computes the inclusive or exclusive scan of the input array and stores the result in the output array.
+    The scan operation computes a running sum of elements in the array.
+
+    Args:
+        in_array (wp.array): Input array to scan. Must be of type int32 or float32.
+        out_array (wp.array): Output array to store scan results. Must match input array type and size.
+        inclusive (bool, optional): If True, performs an inclusive scan (includes current element in sum).
+            If False, performs an exclusive scan (excludes current element). Defaults to True.
+
+    Raises:
+        RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
+    """
+
     if in_array.device != out_array.device:
-        raise RuntimeError("Array storage devices do not match")
+        raise RuntimeError(f"In and out array storage devices do not match ({in_array.device} vs {out_array.device})")

     if in_array.size != out_array.size:
-        raise RuntimeError("Array storage sizes do not match")
+        raise RuntimeError(f"In and out array storage sizes do not match ({in_array.size} vs {out_array.size})")

-    if in_array.dtype != out_array.dtype:
-        raise RuntimeError("Array data types do not match")
+    if not types_equal(in_array.dtype, out_array.dtype):
+        raise RuntimeError(
+            f"In and out array data types do not match ({type_repr(in_array.dtype)} vs {type_repr(out_array.dtype)})"
+        )

     if in_array.size == 0:
         return
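
The docstring added above implies usage along these lines; a minimal sketch (names and values are illustrative, and it assumes the runtime has been initialized with wp.init()):

    import warp as wp
    import warp.utils

    wp.init()

    values = wp.array([1, 2, 3, 4], dtype=wp.int32)
    result = wp.zeros(4, dtype=wp.int32)

    # inclusive scan keeps the current element: [1, 3, 6, 10]
    wp.utils.array_scan(values, result, inclusive=True)

    # exclusive scan shifts it out: [0, 1, 3, 6]
    wp.utils.array_scan(values, result, inclusive=False)
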
@@ -116,25 +135,39 @@ def array_scan(in_array, out_array, inclusive=True):
         elif in_array.dtype == wp.float32:
             runtime.core.array_scan_float_host(in_array.ptr, out_array.ptr, in_array.size, inclusive)
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(in_array.dtype)}")
     elif in_array.device.is_cuda:
         if in_array.dtype == wp.int32:
             runtime.core.array_scan_int_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
         elif in_array.dtype == wp.float32:
             runtime.core.array_scan_float_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(in_array.dtype)}")


 def radix_sort_pairs(keys, values, count: int):
+    """Sort key-value pairs using radix sort.
+
+    This function sorts pairs of arrays based on the keys array, maintaining the key-value
+    relationship. The sort is stable and operates in linear time.
+    The `keys` and `values` arrays must be large enough to accommodate 2*`count` elements.
+
+    Args:
+        keys (wp.array): Array of keys to sort. Must be of type int32, float32, or int64.
+        values (wp.array): Array of values to sort along with keys. Must be of type int32.
+        count (int): Number of elements to sort.
+
+    Raises:
+        RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
+    """
     if keys.device != values.device:
-        raise RuntimeError("Array storage devices do not match")
+        raise RuntimeError(f"Keys and values array storage devices do not match ({keys.device} vs {values.device})")

     if count == 0:
         return

     if keys.size < 2 * count or values.size < 2 * count:
-        raise RuntimeError("Array storage must be large enough to contain 2*count elements")
+        raise RuntimeError("Keys and values array storage must be large enough to contain 2*count elements")

     from warp.context import runtime

@@ -146,7 +179,9 @@ def radix_sort_pairs(keys, values, count: int):
         elif keys.dtype == wp.int64 and values.dtype == wp.int32:
             runtime.core.radix_sort_pairs_int64_host(keys.ptr, values.ptr, count)
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(
+                f"Unsupported keys and values data types: {type_repr(keys.dtype)}, {type_repr(values.dtype)}"
+            )
     elif keys.device.is_cuda:
         if keys.dtype == wp.int32 and values.dtype == wp.int32:
             runtime.core.radix_sort_pairs_int_device(keys.ptr, values.ptr, count)
@@ -155,7 +190,9 @@ def radix_sort_pairs(keys, values, count: int):
         elif keys.dtype == wp.int64 and values.dtype == wp.int32:
             runtime.core.radix_sort_pairs_int64_device(keys.ptr, values.ptr, count)
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(
+                f"Unsupported keys and values data types: {type_repr(keys.dtype)}, {type_repr(values.dtype)}"
+            )


 def segmented_sort_pairs(
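
The 2*count storage requirement documented above means callers allocate double-length buffers, with the second half used as scratch space; reading the sorted result back from the first count entries is how the sketch below assumes it works:

    import numpy as np
    import warp as wp
    import warp.utils

    wp.init()

    n = 4
    # capacity 2*n: the second half is scratch space for the sort
    keys = wp.array(np.array([3, 1, 2, 0] + [0] * n, dtype=np.int32))
    values = wp.array(np.array([0, 1, 2, 3] + [0] * n, dtype=np.int32))

    wp.utils.radix_sort_pairs(keys, values, count=n)
    print(keys.numpy()[:n])    # [0 1 2 3]
    print(values.numpy()[:n])  # [3 1 2 0]
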
@@ -169,6 +206,7 @@ def segmented_sort_pairs(

    This function performs a segmented sort of key-value pairs, where the sorting is done independently within each segment.
    The segments are defined by their start and optionally end indices.
+    The `keys` and `values` arrays must be large enough to accommodate 2*`count` elements.

    Args:
        keys: Array of keys to sort. Must be of type int32 or float32.
@@ -187,7 +225,7 @@ def segmented_sort_pairs(
            if segment_start_indices is not of type int32, or if data types are unsupported.
    """
    if keys.device != values.device:
-        raise RuntimeError("Array storage devices do not match")
+        raise RuntimeError(f"Array storage devices do not match ({keys.device} vs {values.device})")

    if count == 0:
        return
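
A hedged sketch of the segmented variant; the full parameter list is elided in this diff, so passing the segment start indices as the fourth positional argument below is an assumption based on the docstring:

    import numpy as np
    import warp as wp
    import warp.utils

    wp.init()

    n = 6
    # two segments, [5, 2, 3] and [4, 0, 1], plus 2*n capacity for scratch space
    keys = wp.array(np.array([5, 2, 3, 4, 0, 1] + [0] * n, dtype=np.int32))
    values = wp.array(np.array([0, 1, 2, 3, 4, 5] + [0] * n, dtype=np.int32))
    segment_starts = wp.array(np.array([0, 3], dtype=np.int32))

    wp.utils.segmented_sort_pairs(keys, values, n, segment_starts)
    print(keys.numpy()[:n])  # each segment sorted independently: [2 3 5 0 1 4]
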
@@ -219,39 +257,80 @@ def segmented_sort_pairs(
     if keys.device.is_cpu:
         if keys.dtype == wp.int32 and values.dtype == wp.int32:
             runtime.core.segmented_sort_pairs_int_host(
-                keys.ptr, values.ptr, count, segment_start_indices_ptr, segment_end_indices_ptr, num_segments
+                keys.ptr,
+                values.ptr,
+                count,
+                segment_start_indices_ptr,
+                segment_end_indices_ptr,
+                num_segments,
             )
         elif keys.dtype == wp.float32 and values.dtype == wp.int32:
             runtime.core.segmented_sort_pairs_float_host(
-                keys.ptr, values.ptr, count, segment_start_indices_ptr, segment_end_indices_ptr, num_segments
+                keys.ptr,
+                values.ptr,
+                count,
+                segment_start_indices_ptr,
+                segment_end_indices_ptr,
+                num_segments,
             )
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(keys.dtype)}")
     elif keys.device.is_cuda:
         if keys.dtype == wp.int32 and values.dtype == wp.int32:
             runtime.core.segmented_sort_pairs_int_device(
-                keys.ptr, values.ptr, count, segment_start_indices_ptr, segment_end_indices_ptr, num_segments
+                keys.ptr,
+                values.ptr,
+                count,
+                segment_start_indices_ptr,
+                segment_end_indices_ptr,
+                num_segments,
            )
         elif keys.dtype == wp.float32 and values.dtype == wp.int32:
             runtime.core.segmented_sort_pairs_float_device(
-                keys.ptr, values.ptr, count, segment_start_indices_ptr, segment_end_indices_ptr, num_segments
+                keys.ptr,
+                values.ptr,
+                count,
+                segment_start_indices_ptr,
+                segment_end_indices_ptr,
+                num_segments,
             )
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(keys.dtype)}")


 def runlength_encode(values, run_values, run_lengths, run_count=None, value_count=None):
+    """Perform run-length encoding on an array.
+
+    This function compresses an array by replacing consecutive identical values with a single value
+    and its count. For example, [1,1,1,2,2,3] becomes values=[1,2,3] and lengths=[3,2,1].
+
+    Args:
+        values (wp.array): Input array to encode. Must be of type int32.
+        run_values (wp.array): Output array to store unique values. Must be at least value_count in size.
+        run_lengths (wp.array): Output array to store run lengths. Must be at least value_count in size.
+        run_count (wp.array, optional): Optional output array to store the number of runs.
+            If None, returns the count as an integer.
+        value_count (int, optional): Number of values to process. If None, processes entire array.
+
+    Returns:
+        int or wp.array: Number of runs if run_count is None, otherwise returns run_count array.
+
+    Raises:
+        RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
+    """
     if run_values.device != values.device or run_lengths.device != values.device:
-        raise RuntimeError("Array storage devices do not match")
+        raise RuntimeError("run_values, run_lengths and values storage devices do not match")

     if value_count is None:
         value_count = values.size

     if run_values.size < value_count or run_lengths.size < value_count:
-        raise RuntimeError("Output array storage sizes must be at least equal to value_count")
+        raise RuntimeError(f"Output array storage sizes must be at least equal to value_count ({value_count})")

-    if values.dtype != run_values.dtype:
-        raise RuntimeError("values and run_values data types do not match")
+    if not types_equal(values.dtype, run_values.dtype):
+        raise RuntimeError(
+            f"values and run_values data types do not match ({type_repr(values.dtype)} vs {type_repr(run_values.dtype)})"
+        )

     if run_lengths.dtype != wp.int32:
         raise RuntimeError("run_lengths array must be of type int32")
@@ -270,7 +349,7 @@ def runlength_encode(values, run_values, run_lengths, run_count=None, value_coun
             raise RuntimeError("run_count array must be of type int32")
         if value_count == 0:
             run_count.zero_()
-            return 0
+            return run_count
         host_return = False

     from warp.context import runtime
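
The [1,1,1,2,2,3] example from the new docstring, written out; a minimal sketch (with run_count left as None, the run count comes back as a host integer):

    import numpy as np
    import warp as wp
    import warp.utils

    wp.init()

    values = wp.array(np.array([1, 1, 1, 2, 2, 3], dtype=np.int32))
    run_values = wp.empty(values.size, dtype=wp.int32)
    run_lengths = wp.empty(values.size, dtype=wp.int32)

    num_runs = wp.utils.runlength_encode(values, run_values, run_lengths)
    print(num_runs)                        # 3
    print(run_values.numpy()[:num_runs])   # [1 2 3]
    print(run_lengths.numpy()[:num_runs])  # [3 2 1]
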
@@ -281,20 +360,39 @@ def runlength_encode(values, run_values, run_lengths, run_count=None, value_coun
                 values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
             )
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")
     elif values.device.is_cuda:
         if values.dtype == wp.int32:
             runtime.core.runlength_encode_int_device(
                 values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
             )
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")

     if host_return:
         return int(run_count.numpy()[0])
+    return run_count


 def array_sum(values, out=None, value_count=None, axis=None):
+    """Compute the sum of array elements.
+
+    This function computes the sum of array elements, optionally along a specified axis.
+    The operation can be performed on the entire array or along a specific dimension.
+
+    Args:
+        values (wp.array): Input array to sum. Must be of type float32 or float64.
+        out (wp.array, optional): Output array to store results. If None, a new array is created.
+        value_count (int, optional): Number of elements to process. If None, processes entire array.
+        axis (int, optional): Axis along which to compute sum. If None, computes sum of all elements.
+
+    Returns:
+        wp.array or float: The sum result. Returns a float if axis is None and out is None,
+            otherwise returns the output array.
+
+    Raises:
+        RuntimeError: If output array storage device or data type is incompatible with input array.
+    """
     if value_count is None:
         if axis is None:
             value_count = values.size
@@ -310,7 +408,7 @@ def array_sum(values, out=None, value_count=None, axis=None):

     output_shape = tuple(output_dim(ax, dim) for ax, dim in enumerate(values.shape))

-    type_length = wp.types.type_length(values.dtype)
+    type_size = wp.types.type_size(values.dtype)
     scalar_type = wp.types.type_scalar_type(values.dtype)

     # User can provide a device output array for storing the number of runs
@@ -341,48 +439,67 @@ def array_sum(values, out=None, value_count=None, axis=None):
         elif scalar_type == wp.float64:
             native_func = runtime.core.array_sum_double_host
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")
     elif values.device.is_cuda:
         if scalar_type == wp.float32:
             native_func = runtime.core.array_sum_float_device
         elif scalar_type == wp.float64:
             native_func = runtime.core.array_sum_double_device
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")

     if axis is None:
         stride = wp.types.type_size_in_bytes(values.dtype)
-        native_func(values.ptr, out.ptr, value_count, stride, type_length)
+        native_func(values.ptr, out.ptr, value_count, stride, type_size)

         if host_return:
             return out.numpy()[0]
-    else:
-        stride = values.strides[axis]
-        for idx in np.ndindex(output_shape):
-            out_offset = sum(i * s for i, s in zip(idx, out.strides))
-            val_offset = sum(i * s for i, s in zip(idx, values.strides))
-
-            native_func(
-                values.ptr + val_offset,
-                out.ptr + out_offset,
-                value_count,
-                stride,
-                type_length,
-            )
+        return out

-    if host_return:
-        return out
+    stride = values.strides[axis]
+    for idx in np.ndindex(output_shape):
+        out_offset = sum(i * s for i, s in zip(idx, out.strides))
+        val_offset = sum(i * s for i, s in zip(idx, values.strides))
+
+        native_func(
+            values.ptr + val_offset,
+            out.ptr + out_offset,
+            value_count,
+            stride,
+            type_size,
+        )
+
+    return out


 def array_inner(a, b, out=None, count=None, axis=None):
+    """Compute the inner product of two arrays.
+
+    This function computes the dot product between two arrays, optionally along a specified axis.
+    The operation can be performed on the entire arrays or along a specific dimension.
+
+    Args:
+        a (wp.array): First input array.
+        b (wp.array): Second input array. Must match shape and type of a.
+        out (wp.array, optional): Output array to store results. If None, a new array is created.
+        count (int, optional): Number of elements to process. If None, processes entire arrays.
+        axis (int, optional): Axis along which to compute inner product. If None, computes on flattened arrays.
+
+    Returns:
+        wp.array or float: The inner product result. Returns a float if axis is None and out is None,
+            otherwise returns the output array.
+
+    Raises:
+        RuntimeError: If array storage devices, sizes, or data types are incompatible.
+    """
     if a.size != b.size:
-        raise RuntimeError("Array storage sizes do not match")
+        raise RuntimeError(f"A and b array storage sizes do not match ({a.size} vs {b.size})")

     if a.device != b.device:
-        raise RuntimeError("Array storage devices do not match")
+        raise RuntimeError(f"A and b array storage devices do not match ({a.device} vs {b.device})")

-    if a.dtype != b.dtype:
-        raise RuntimeError("Array data types do not match")
+    if not types_equal(a.dtype, b.dtype):
+        raise RuntimeError(f"A and b array data types do not match ({type_repr(a.dtype)} vs {type_repr(b.dtype)})")

     if count is None:
         if axis is None:
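
Note the control-flow change in this hunk: the axis branch now returns the output array unconditionally instead of only when host_return is set. A minimal sketch (judging from output_dim above, the reduced axis is kept with size 1):

    import numpy as np
    import warp as wp
    import warp.utils

    wp.init()

    values = wp.array(np.ones((3, 4), dtype=np.float32))

    total = wp.utils.array_sum(values)         # sum of all elements
    rows = wp.utils.array_sum(values, axis=1)  # per-row sums, shape (3, 1)
    print(total)                 # 12.0
    print(rows.numpy().ravel())  # [4. 4. 4.]
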
@@ -399,7 +516,7 @@ def array_inner(a, b, out=None, count=None, axis=None):

     output_shape = tuple(output_dim(ax, dim) for ax, dim in enumerate(a.shape))

-    type_length = wp.types.type_length(a.dtype)
+    type_size = wp.types.type_size(a.dtype)
     scalar_type = wp.types.type_scalar_type(a.dtype)

     # User can provide a device output array for storing the number of runs
@@ -430,43 +547,43 @@ def array_inner(a, b, out=None, count=None, axis=None):
         elif scalar_type == wp.float64:
             native_func = runtime.core.array_inner_double_host
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(a.dtype)}")
     elif a.device.is_cuda:
         if scalar_type == wp.float32:
             native_func = runtime.core.array_inner_float_device
         elif scalar_type == wp.float64:
             native_func = runtime.core.array_inner_double_device
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(a.dtype)}")

     if axis is None:
         stride_a = wp.types.type_size_in_bytes(a.dtype)
         stride_b = wp.types.type_size_in_bytes(b.dtype)
-        native_func(a.ptr, b.ptr, out.ptr, count, stride_a, stride_b, type_length)
+        native_func(a.ptr, b.ptr, out.ptr, count, stride_a, stride_b, type_size)

         if host_return:
             return out.numpy()[0]
-    else:
-        stride_a = a.strides[axis]
-        stride_b = b.strides[axis]
-
-        for idx in np.ndindex(output_shape):
-            out_offset = sum(i * s for i, s in zip(idx, out.strides))
-            a_offset = sum(i * s for i, s in zip(idx, a.strides))
-            b_offset = sum(i * s for i, s in zip(idx, b.strides))
-
-            native_func(
-                a.ptr + a_offset,
-                b.ptr + b_offset,
-                out.ptr + out_offset,
-                count,
-                stride_a,
-                stride_b,
-                type_length,
-            )
+        return out

-    if host_return:
-        return out
+    stride_a = a.strides[axis]
+    stride_b = b.strides[axis]
+
+    for idx in np.ndindex(output_shape):
+        out_offset = sum(i * s for i, s in zip(idx, out.strides))
+        a_offset = sum(i * s for i, s in zip(idx, a.strides))
+        b_offset = sum(i * s for i, s in zip(idx, b.strides))
+
+        native_func(
+            a.ptr + a_offset,
+            b.ptr + b_offset,
+            out.ptr + out_offset,
+            count,
+            stride_a,
+            stride_b,
+            type_size,
+        )
+
+    return out


 @wp.kernel
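
array_inner follows the same pattern as array_sum, including the new unconditional return of out. A minimal sketch:

    import numpy as np
    import warp as wp
    import warp.utils

    wp.init()

    a = wp.array(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    b = wp.array(np.array([4.0, 5.0, 6.0], dtype=np.float32))

    # 1*4 + 2*5 + 3*6
    print(wp.utils.array_inner(a, b))  # 32.0
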
@@ -479,8 +596,28 @@ def _array_cast_kernel(


 def array_cast(in_array, out_array, count=None):
+    """Cast elements from one array to another array with a different data type.
+
+    This function performs element-wise casting from the input array to the output array.
+    The arrays must have the same number of dimensions and data type shapes. If they don't match,
+    the arrays will be flattened and casting will be performed at the scalar level.
+
+    Args:
+        in_array (wp.array): Input array to cast from.
+        out_array (wp.array): Output array to cast to. Must have the same device as in_array.
+        count (int, optional): Number of elements to process. If None, processes entire array.
+            For multi-dimensional arrays, partial casting is not supported.
+
+    Raises:
+        RuntimeError: If arrays have different devices or if attempting partial casting
+            on multi-dimensional arrays.
+
+    Note:
+        If the input and output arrays have the same data type, this function will
+        simply copy the data without any conversion.
+    """
     if in_array.device != out_array.device:
-        raise RuntimeError("Array storage devices do not match")
+        raise RuntimeError(f"Array storage devices do not match ({in_array.device} vs {out_array.device})")

     in_array_data_shape = getattr(in_array.dtype, "_shape_", ())
     out_array_data_shape = getattr(out_array.dtype, "_shape_", ())
@@ -491,8 +628,8 @@ def array_cast(in_array, out_array, count=None):
         in_array = in_array.flatten()
         out_array = out_array.flatten()

-    in_array_data_length = warp.types.type_length(in_array.dtype)
-    out_array_data_length = warp.types.type_length(out_array.dtype)
+    in_array_data_length = warp.types.type_size(in_array.dtype)
+    out_array_data_length = warp.types.type_size(out_array.dtype)
     in_array_scalar_type = wp.types.type_scalar_type(in_array.dtype)
     out_array_scalar_type = wp.types.type_scalar_type(out_array.dtype)

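
The call sites switch from wp.types.type_length to wp.types.type_size; array_cast behavior is otherwise unchanged. A minimal sketch (integer conversion is expected to truncate, as with a C-style cast):

    import numpy as np
    import warp as wp
    import warp.utils

    wp.init()

    src = wp.array(np.array([1.5, 2.7, 3.1], dtype=np.float32))
    dst = wp.empty(3, dtype=wp.int32)

    wp.utils.array_cast(src, dst)  # float32 -> int32
    print(dst.numpy())  # [1 2 3]
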
@@ -534,6 +671,430 @@ def array_cast(in_array, out_array, count=None):
         wp.launch(kernel=_array_cast_kernel, dim=dim, inputs=[out_array, in_array], device=out_array.device)


+def create_warp_function(func: Callable) -> tuple[wp.Function, warp.context.Module]:
+    """Create a Warp function from a Python function.
+
+    Args:
+        func (Callable): A Python function to be converted to a Warp function.
+
+    Returns:
+        wp.Function: A Warp function created from the input function.
+    """
+
+    from .codegen import Adjoint, get_full_arg_spec
+
+    def unique_name(code: str):
+        return "func_" + hex(hash(code))[-8:]
+
+    # Create a Warp function from the input function
+    source = None
+    argspec = get_full_arg_spec(func)
+    key = getattr(func, "__name__", None)
+    if key is None:
+        source, _ = Adjoint.extract_function_source(func)
+        key = unique_name(source)
+    elif key == "<lambda>":
+        body = Adjoint.extract_lambda_source(func, only_body=True)
+        if body is None:
+            raise ValueError("Could not extract lambda source code")
+        key = unique_name(body)
+        source = f"def {key}({', '.join(argspec.args)}):\n    return {body}"
+    else:
+        # use the qualname of the function as the key
+        key = getattr(func, "__qualname__", key)
+        key = key.replace(".", "_").replace(" ", "_").replace("<", "").replace(">", "_")
+
+    module = warp.context.get_module(f"map_{key}")
+    func = wp.Function(
+        func,
+        namespace="",
+        module=module,
+        key=key,
+        source=source,
+        overloaded_annotations=dict.fromkeys(argspec.args, Any),
+    )
+    return func, module
+
+
+def broadcast_shapes(shapes: list[tuple[int]]) -> tuple[int]:
+    """Broadcast a list of shapes to a common shape.
+
+    Following the broadcasting rules of NumPy, two shapes are compatible when:
+    starting from the trailing dimension,
+    1. the two dimensions are equal, or
+    2. one of the dimensions is 1.
+
+    Example:
+        >>> broadcast_shapes([(3, 1, 4), (5, 4)])
+        (3, 5, 4)
+
+    Returns:
+        tuple[int]: The broadcasted shape.
+
+    Raises:
+        ValueError: If the shapes are not broadcastable.
+    """
+    ref = shapes[0]
+    for shape in shapes[1:]:
+        broad = []
+        for j in range(1, max(len(ref), len(shape)) + 1):
+            if j <= len(ref) and j <= len(shape):
+                s = shape[-j]
+                r = ref[-j]
+                if s == r:
+                    broad.append(s)
+                elif s == 1 or r == 1:
+                    broad.append(max(s, r))
+                else:
+                    raise ValueError(f"Shapes {ref} and {shape} are not broadcastable")
+            elif j <= len(ref):
+                broad.append(ref[-j])
+            else:
+                broad.append(shape[-j])
+        ref = tuple(reversed(broad))
+    return ref
+
+
+def map(
+    func: Callable | wp.Function,
+    *inputs: Array[DType] | Any,
+    out: Array[DType] | list[Array[DType]] | None = None,
+    return_kernel: bool = False,
+    block_dim=256,
+    device: Devicelike = None,
+) -> Array[DType] | list[Array[DType]] | wp.Kernel:
+    """
+    Map a function over the elements of one or more arrays.
+
+    You can use a Warp function, a regular Python function, or a lambda expression to map it to a set of arrays.
+
+    .. testcode::
+
+        a = wp.array([1, 2, 3], dtype=wp.float32)
+        b = wp.array([4, 5, 6], dtype=wp.float32)
+        c = wp.array([7, 8, 9], dtype=wp.float32)
+        result = wp.map(lambda x, y, z: x + 2.0 * y - z, a, b, c)
+        print(result)
+
+    .. testoutput::
+
+        [2. 4. 6.]
+
+    Clamp values in an array in place:
+
+    .. testcode::
+
+        xs = wp.array([-1.0, 0.0, 1.0], dtype=wp.float32)
+        wp.map(wp.clamp, xs, -0.5, 0.5, out=xs)
+        print(xs)
+
+    .. testoutput::
+
+        [-0.5  0.   0.5]
+
+    Note that only one of the inputs needs to be a Warp array. For example, it is possible
+    to vectorize the function :func:`warp.transform_point` over a collection of points
+    with a given input transform as follows:
+
+    .. code-block:: python
+
+        tf = wp.transform((1.0, 2.0, 3.0), wp.quat_rpy(0.2, -0.6, 0.1))
+        points = wp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=wp.vec3)
+        transformed = wp.map(wp.transform_point, tf, points)
+
+    Besides regular Warp arrays, other array types, such as the ``indexedarray``, are supported as well:
+
+    .. testcode::
+
+        arr = wp.array(data=np.arange(10, dtype=np.float32))
+        indices = wp.array([1, 3, 5, 7, 9], dtype=int)
+        iarr = wp.indexedarray1d(arr, [indices])
+        out = wp.map(lambda x: x * 10.0, iarr)
+        print(out)
+
+    .. testoutput::
+
+        [10. 30. 50. 70. 90.]
+
+    If multiple arrays are provided, the
+    `NumPy broadcasting rules <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
+    are applied to determine the shape of the output array.
+    Two shapes are compatible when:
+    starting from the trailing dimension,
+
+    1. the two dimensions are equal, or
+    2. one of the dimensions is 1.
+
+    For example, given arrays of shapes ``(3, 1, 4)`` and ``(5, 4)``, the broadcasted
+    shape is ``(3, 5, 4)``.
+
+    If no array(s) are provided to the ``out`` argument, the output array(s) are created automatically.
+    The data type(s) of the output array(s) are determined by the type of the return value(s) of
+    the function. The ``requires_grad`` flag for an automatically created output array is set to ``True``
+    if any of the input arrays have it set to ``True`` and the respective output array's ``dtype`` is a type that
+    supports differentiation.
+
+    Args:
+        func (Callable | Function): The function to map over the arrays.
+        *inputs (array | Any): The input arrays or values to pass to the function.
+        out (array | list[array] | None): Optional output array(s) to store the result(s). If None, the output array(s) will be created automatically.
+        return_kernel (bool): If True, only return the generated kernel without performing the mapping operation.
+        block_dim (int): The block dimension for the kernel launch.
+        device (Devicelike): The device on which to run the kernel.
+
+    Returns:
+        array | list[array] | Kernel:
+            The resulting array(s) of the mapping. If ``return_kernel`` is True, only returns the kernel used for mapping.
+    """
+
+    import builtins
+
+    from .codegen import Adjoint, Struct, StructInstance
+    from .types import (
+        is_array,
+        type_is_matrix,
+        type_is_quaternion,
+        type_is_transformation,
+        type_is_vector,
+        type_repr,
+        type_to_warp,
+        types_equal,
+    )
+
+    # mapping from struct name to its Python definition
+    referenced_modules: dict[str, ModuleType] = {}
+
+    def type_to_code(wp_type) -> str:
+        """Returns the string representation of a given Warp type."""
+        if is_array(wp_type):
+            return f"warp.array(ndim={wp_type.ndim}, dtype={type_to_code(wp_type.dtype)})"
+        if isinstance(wp_type, Struct):
+            key = f"{wp_type.__module__}.{wp_type.key}"
+            module = sys.modules.get(wp_type.__module__, None)
+            if module is not None:
+                referenced_modules[wp_type.__module__] = module
+            return key
+        if type_is_transformation(wp_type):
+            return f"warp.types.transformation(dtype={type_to_code(wp_type._wp_scalar_type_)})"
+        if type_is_quaternion(wp_type):
+            return f"warp.types.quaternion(dtype={type_to_code(wp_type._wp_scalar_type_)})"
+        if type_is_vector(wp_type):
+            return f"warp.types.vector(length={wp_type._shape_[0]}, dtype={type_to_code(wp_type._wp_scalar_type_)})"
+        if type_is_matrix(wp_type):
+            return f"warp.types.matrix(shape=({wp_type._shape_[0]}, {wp_type._shape_[1]}), dtype={type_to_code(wp_type._wp_scalar_type_)})"
+        if wp_type == builtins.bool:
+            return "bool"
+        if wp_type == builtins.float:
+            return "float"
+        if wp_type == builtins.int:
+            return "int"
+
+        name = getattr(wp_type, "__name__", None)
+        if name is None:
+            return type_repr(wp_type)
+        name = getattr(wp_type, "__qualname__", name)
+        module = getattr(wp_type, "__module__", None)
+        if module is not None:
+            referenced_modules[wp_type.__module__] = module
+        return wp_type.__module__ + "." + name
+
+    def get_warp_type(value):
+        dtype = type(value)
+        if issubclass(dtype, StructInstance):
+            # a struct
+            return value._cls
+        return type_to_warp(dtype)
+
+    # gather the arrays in the inputs
+    array_shapes = [a.shape for a in inputs if is_array(a)]
+    if len(array_shapes) == 0:
+        raise ValueError("map requires at least one warp.array input")
+    # broadcast the shapes of the arrays
+    out_shape = broadcast_shapes(array_shapes)
+
+    module = None
+    out_dtypes = None
+    skip_arg_type_checks = False
+    if isinstance(func, wp.Function):
+        func_name = func.key
+        wp_func = func
+    else:
+        # check if op is a callable function
+        if not callable(func):
+            raise TypeError("func must be a callable function or a warp.Function")
+        wp_func, module = create_warp_function(func)
+        func_name = wp_func.key
+        # we created a generic function here (arg types are all Any)
+        skip_arg_type_checks = True
+    if module is None:
+        module = warp.context.get_module(f"map_{func_name}")
+
+    arg_names = list(wp_func.input_types.keys())
+    # determine output dtype
+    if wp_func.value_func is not None or wp_func.value_type is not None:
+        arg_types = {}
+        arg_values = {}
+        for i, arg_name in enumerate(arg_names):
+            if is_array(inputs[i]):
+                # we will pass an element of the array to the function
+                arg_types[arg_name] = inputs[i].dtype
+                if device is None:
+                    device = inputs[i].device
+            else:
+                # we pass the input value directly to the function
+                arg_types[arg_name] = get_warp_type(inputs[i])
+        func_or_none = wp_func.get_overload(list(arg_types.values()), {})
+        if func_or_none is None:
+            raise TypeError(
+                f"Function {func_name} does not support the provided argument types {', '.join(type_repr(t) for t in arg_types.values())}"
+            )
+        func = func_or_none
+        if func.value_func is not None:
+            out_dtype = func.value_func(arg_types, arg_values)
+        else:
+            out_dtype = func.value_type
+        if isinstance(out_dtype, tuple) or isinstance(out_dtype, list):
+            out_dtypes = out_dtype
+        else:
+            out_dtypes = (out_dtype,)
+    else:
+        # try to evaluate the function to determine the output type
+        args = []
+        arg_types = wp_func.input_types
+        if len(inputs) != len(arg_types):
+            raise TypeError(
+                f"Number of input arguments ({len(inputs)}) does not match expected number of function arguments ({len(arg_types)})"
+            )
+        for (arg_name, arg_type), input in zip(arg_types.items(), inputs):
+            if is_array(input):
+                if not skip_arg_type_checks and not types_equal(input.dtype, arg_type):
+                    raise TypeError(
+                        f'Incorrect input provided for argument "{arg_name}": received array of dtype {type_repr(input.dtype)}, expected {type_repr(arg_type)}'
+                    )
+                args.append(input.dtype())
+                if device is None:
+                    device = input.device
+            else:
+                if not skip_arg_type_checks and not types_equal(type(input), arg_type):
+                    raise TypeError(
+                        f'Incorrect input provided for argument "{arg_name}": received {type_repr(type(input))}, expected {type_repr(arg_type)}'
+                    )
+                args.append(input)
+        result = wp_func(*args)
+        if result is None:
+            raise TypeError("The provided function must return a value")
+        if isinstance(result, tuple) or isinstance(result, list):
+            out_dtypes = tuple(get_warp_type(r) for r in result)
+        else:
+            out_dtypes = (get_warp_type(result),)
+
+    if out_dtypes is None:
+        raise TypeError("Could not determine the output type of the function, make sure it returns a value")
+
+    if out is None:
+        requires_grad = any(getattr(a, "requires_grad", False) for a in inputs if is_array(a))
+        outputs = []
+        for dtype in out_dtypes:
+            rg = requires_grad and Adjoint.is_differentiable_value_type(dtype)
+            outputs.append(wp.empty(out_shape, dtype=dtype, requires_grad=rg, device=device))
+    elif len(out_dtypes) == 1 and is_array(out):
+        if not types_equal(out.dtype, out_dtypes[0]):
+            raise TypeError(
+                f"Output array dtype {type_repr(out.dtype)} does not match expected dtype {type_repr(out_dtypes[0])}"
+            )
+        if out.shape != out_shape:
+            raise TypeError(f"Output array shape {out.shape} does not match expected shape {out_shape}")
+        outputs = [out]
+    elif len(out_dtypes) > 1:
+        if isinstance(out, tuple) or isinstance(out, list):
+            if len(out) != len(out_dtypes):
+                raise TypeError(
+                    f"Number of provided output arrays ({len(out)}) does not match expected number of function outputs ({len(out_dtypes)})"
+                )
+            for i, a in enumerate(out):
+                if not types_equal(a.dtype, out_dtypes[i]):
+                    raise TypeError(
+                        f"Output array {i} dtype {type_repr(a.dtype)} does not match expected dtype {type_repr(out_dtypes[i])}"
+                    )
+                if a.shape != out_shape:
+                    raise TypeError(f"Output array {i} shape {a.shape} does not match expected shape {out_shape}")
+            outputs = list(out)
+        else:
+            raise TypeError(
+                f"Invalid output provided, expected {len(out_dtypes)} Warp arrays with shape {out_shape} and dtypes ({', '.join(type_repr(t) for t in out_dtypes)})"
+            )
+
+    # create code for a kernel
+    code = """def map_kernel({kernel_args}):
+    {tids} = wp.tid()
+    {load_args}
+    """
+    if len(outputs) == 1:
+        code += "__out_0[{tids}] = {func_name}({arg_names})"
+    else:
+        code += ", ".join(f"__o_{i}" for i in range(len(outputs)))
+        code += " = {func_name}({arg_names})\n"
+        for i in range(len(outputs)):
+            code += f"    __out_{i}" + "[{tids}]" + f" = __o_{i}\n"
+
+    tids = [f"__tid_{i}" for i in range(len(out_shape))]
+
+    load_args = []
+    kernel_args = []
+    for arg_name, input in zip(arg_names, inputs):
+        if is_array(input):
+            arr_name = f"{arg_name}_array"
+            array_type_name = type(input).__name__
+            kernel_args.append(
+                f"{arr_name}: wp.{array_type_name}(dtype={type_to_code(input.dtype)}, ndim={input.ndim})"
+            )
+            shape = input.shape
+            indices = []
+            for i in range(1, len(shape) + 1):
+                if shape[-i] == 1:
+                    indices.append("0")
+                else:
+                    indices.append(tids[-i])
+
+            load_args.append(f"{arg_name} = {arr_name}[{', '.join(reversed(indices))}]")
+        else:
+            kernel_args.append(f"{arg_name}: {type_to_code(type(input))}")
+    for i, o in enumerate(outputs):
+        array_type_name = type(o).__name__
+        kernel_args.append(f"__out_{i}: wp.{array_type_name}(dtype={type_to_code(o.dtype)}, ndim={o.ndim})")
+    code = code.format(
+        func_name=func_name,
+        kernel_args=", ".join(kernel_args),
+        arg_names=", ".join(arg_names),
+        tids=", ".join(tids),
+        load_args="\n    ".join(load_args),
+    )
+    namespace = {}
+    namespace.update({"wp": wp, "warp": wp, func_name: wp_func, "Any": Any})
+    namespace.update(referenced_modules)
+    exec(code, namespace)
+
+    kernel = wp.Kernel(namespace["map_kernel"], key="map_kernel", source=code, module=module)
+    if return_kernel:
+        return kernel
+
+    wp.launch(
+        kernel,
+        dim=out_shape,
+        inputs=inputs,
+        outputs=outputs,
+        block_dim=block_dim,
+        device=device,
+    )
+
+    if len(outputs) == 1:
+        o = outputs[0]
+    else:
+        o = outputs
+
+    return o
+
+
 # code snippet for invoking cProfile
 # cp = cProfile.Profile()
 # cp.enable()
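
The new wp.map applies NumPy-style broadcasting across all array inputs; a sketch of the (3, 1, 4) x (5, 4) case from the broadcast_shapes docstring above:

    import numpy as np
    import warp as wp

    wp.init()

    a = wp.array(np.ones((3, 1, 4), dtype=np.float32))
    b = wp.array(np.ones((5, 4), dtype=np.float32))

    # shapes (3, 1, 4) and (5, 4) broadcast to (3, 5, 4)
    result = wp.map(lambda x, y: x + y, a, b)
    print(result.shape)  # (3, 5, 4)

    # return_kernel=True skips the launch and returns the generated kernel for reuse
    kernel = wp.map(lambda x, y: x + y, a, b, return_kernel=True)
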
@@ -634,7 +1195,7 @@ def mem_report():  # pragma: no cover
             element_size = tensor.storage().element_size()
             mem = numel * element_size / 1024 / 1024  # 32bit=4Byte, MByte
             total_mem += mem
-        print("Type: %s Total Tensors: %d \tUsed Memory Space: %.2f MBytes" % (mem_type, total_numel, total_mem))
+        print(f"Type: {mem_type:<4} | Total Tensors: {total_numel:>8} | Used Memory: {total_mem:>8.2f} MB")

     import gc

@@ -712,7 +1273,7 @@ class ScopedStream:
        device (Device): The device associated with the stream.
    """

-    def __init__(self, stream: Optional[wp.Stream], sync_enter: bool = True, sync_exit: bool = False):
+    def __init__(self, stream: wp.Stream | None, sync_enter: bool = True, sync_exit: bool = False):
        """Initializes the context manager with a stream and synchronization options.

        Args:
@@ -765,12 +1326,12 @@ class ScopedTimer:
        active: bool = True,
        print: bool = True,
        detailed: bool = False,
-        dict: Optional[Dict[str, List[float]]] = None,
+        dict: dict[str, list[float]] | None = None,
        use_nvtx: bool = False,
-        color: Union[int, str] = "rapids",
+        color: int | str = "rapids",
        synchronize: bool = False,
        cuda_filter: int = 0,
-        report_func: Optional[Callable[[List[TimingResult], str], None]] = None,
+        report_func: Callable[[list[TimingResult], str], None] | None = None,
        skip_tape: bool = False,
    ):
        """Context manager object for a timer
@@ -792,7 +1353,7 @@ class ScopedTimer:
        Attributes:
            extra_msg (str): Can be set to a string that will be added to the printout at context exit.
            elapsed (float): The duration of the ``with`` block used with this object
-            timing_results (List[TimingResult]): The list of activity timing results, if collection was requested using ``cuda_filter``
+            timing_results (list[TimingResult]): The list of activity timing results, if collection was requested using ``cuda_filter``
        """
        self.name = name
        self.active = active and self.enabled
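
The ScopedTimer annotations move from typing.Optional/Dict/List to the builtin-generic and union syntax; usage itself is unchanged. A minimal sketch using the dict parameter documented above:

    import warp as wp

    wp.init()

    results = {}
    with wp.ScopedTimer("alloc", dict=results, synchronize=True):
        a = wp.zeros(1_000_000, dtype=wp.float32)

    print(results["alloc"])  # list of elapsed times in milliseconds
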
@@ -986,12 +1547,12 @@ def check_p2p():
 class timing_result_t(ctypes.Structure):
     """CUDA timing struct for fetching values from C++"""

-    _fields_ = [
+    _fields_ = (
         ("context", ctypes.c_void_p),
         ("name", ctypes.c_char_p),
         ("filter", ctypes.c_int),
         ("elapsed", ctypes.c_float),
-    ]
+    )


 class TimingResult:
@@ -1025,7 +1586,7 @@ def timing_begin(cuda_filter: int = TIMING_ALL, synchronize: bool = True) -> Non
     warp.context.runtime.core.cuda_timing_begin(cuda_filter)


-def timing_end(synchronize: bool = True) -> List[TimingResult]:
+def timing_end(synchronize: bool = True) -> list[TimingResult]:
     """End detailed activity timing.

     Parameters:
@@ -1071,7 +1632,7 @@ def timing_end(synchronize: bool = True) -> list[TimingResult]:
     return results


-def timing_print(results: List[TimingResult], indent: str = "") -> None:
+def timing_print(results: list[TimingResult], indent: str = "") -> None:
    """Print timing results.

    Parameters:
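
The timing helpers keep the same begin/end/print flow, now with builtin-generic annotations; a minimal sketch (requires a CUDA device, since the timing hooks are CUDA activity timers):

    import warp as wp
    from warp.utils import timing_begin, timing_end, timing_print

    wp.init()

    timing_begin()
    a = wp.zeros(1_000_000, dtype=wp.float32, device="cuda:0")
    wp.synchronize()
    results = timing_end()
    timing_print(results)
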