warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/utils.py
CHANGED
@@ -21,7 +21,8 @@ import os
 import sys
 import time
 import warnings
-from
+from types import ModuleType
+from typing import Any, Callable

 import numpy as np

@@ -29,6 +30,7 @@ import warp as wp
 import warp.context
 import warp.types
 from warp.context import Devicelike
+from warp.types import Array, DType, type_repr, types_equal

 warnings_seen = set()

@@ -52,7 +54,7 @@ def warp_showwarning(message, category, filename, lineno, file=None, line=None):

         if line:
             line = line.strip()
-            s += "
+            s += f"  {line}\n"
     else:
         # simple warning
         s = f"Warp {category.__name__}: {message}\n"
@@ -96,14 +98,31 @@ def quat_between_vectors(a: wp.vec3, b: wp.vec3) -> wp.quat:


 def array_scan(in_array, out_array, inclusive=True):
+    """Perform a scan (prefix sum) operation on an array.
+
+    This function computes the inclusive or exclusive scan of the input array and stores the result in the output array.
+    The scan operation computes a running sum of elements in the array.
+
+    Args:
+        in_array (wp.array): Input array to scan. Must be of type int32 or float32.
+        out_array (wp.array): Output array to store scan results. Must match input array type and size.
+        inclusive (bool, optional): If True, performs an inclusive scan (includes current element in sum).
+            If False, performs an exclusive scan (excludes current element). Defaults to True.
+
+    Raises:
+        RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
+    """
+
     if in_array.device != out_array.device:
-        raise RuntimeError("
+        raise RuntimeError(f"In and out array storage devices do not match ({in_array.device} vs {out_array.device})")

     if in_array.size != out_array.size:
-        raise RuntimeError("
+        raise RuntimeError(f"In and out array storage sizes do not match ({in_array.size} vs {out_array.size})")

-    if in_array.dtype != out_array.dtype:
-        raise RuntimeError(
+    if not types_equal(in_array.dtype, out_array.dtype):
+        raise RuntimeError(
+            f"In and out array data types do not match ({type_repr(in_array.dtype)} vs {type_repr(out_array.dtype)})"
+        )

     if in_array.size == 0:
         return
@@ -116,25 +135,39 @@ def array_scan(in_array, out_array, inclusive=True):
         elif in_array.dtype == wp.float32:
             runtime.core.array_scan_float_host(in_array.ptr, out_array.ptr, in_array.size, inclusive)
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(in_array.dtype)}")
     elif in_array.device.is_cuda:
         if in_array.dtype == wp.int32:
             runtime.core.array_scan_int_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
         elif in_array.dtype == wp.float32:
             runtime.core.array_scan_float_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(in_array.dtype)}")


 def radix_sort_pairs(keys, values, count: int):
+    """Sort key-value pairs using radix sort.
+
+    This function sorts pairs of arrays based on the keys array, maintaining the key-value
+    relationship. The sort is stable and operates in linear time.
+    The `keys` and `values` arrays must be large enough to accommodate 2*`count` elements.
+
+    Args:
+        keys (wp.array): Array of keys to sort. Must be of type int32, float32, or int64.
+        values (wp.array): Array of values to sort along with keys. Must be of type int32.
+        count (int): Number of elements to sort.
+
+    Raises:
+        RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
+    """
     if keys.device != values.device:
-        raise RuntimeError("
+        raise RuntimeError(f"Keys and values array storage devices do not match ({keys.device} vs {values.device})")

     if count == 0:
         return

     if keys.size < 2 * count or values.size < 2 * count:
-        raise RuntimeError("
+        raise RuntimeError("Keys and values array storage must be large enough to contain 2*count elements")

     from warp.context import runtime

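For reference, the documented `array_scan` behavior maps to usage like this (a minimal sketch, not part of the package diff; the commented results follow from the prefix-sum definition):

    import warp as wp
    from warp.utils import array_scan

    wp.init()

    values = wp.array([1, 2, 3, 4], dtype=wp.int32)
    out = wp.empty_like(values)

    # Inclusive scan keeps the current element in each running sum.
    array_scan(values, out, inclusive=True)   # out -> [1, 3, 6, 10]
    # Exclusive scan shifts the sums right, starting from zero.
    array_scan(values, out, inclusive=False)  # out -> [0, 1, 3, 6]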
@@ -146,7 +179,9 @@ def radix_sort_pairs(keys, values, count: int):
         elif keys.dtype == wp.int64 and values.dtype == wp.int32:
             runtime.core.radix_sort_pairs_int64_host(keys.ptr, values.ptr, count)
         else:
-            raise RuntimeError(
+            raise RuntimeError(
+                f"Unsupported keys and values data types: {type_repr(keys.dtype)}, {type_repr(values.dtype)}"
+            )
     elif keys.device.is_cuda:
         if keys.dtype == wp.int32 and values.dtype == wp.int32:
             runtime.core.radix_sort_pairs_int_device(keys.ptr, values.ptr, count)
@@ -155,7 +190,9 @@ def radix_sort_pairs(keys, values, count: int):
         elif keys.dtype == wp.int64 and values.dtype == wp.int32:
             runtime.core.radix_sort_pairs_int64_device(keys.ptr, values.ptr, count)
         else:
-            raise RuntimeError(
+            raise RuntimeError(
+                f"Unsupported keys and values data types: {type_repr(keys.dtype)}, {type_repr(values.dtype)}"
+            )


 def segmented_sort_pairs(
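The 2*count storage requirement called out in the new docstring means callers must over-allocate; a minimal sketch (not part of the diff; values are illustrative):

    import warp as wp
    from warp.utils import radix_sort_pairs

    wp.init()

    n = 4
    # Both arrays hold 2*n elements; the second half is internal scratch space.
    keys = wp.array([7, 2, 9, 4, 0, 0, 0, 0], dtype=wp.int32)
    values = wp.array([0, 1, 2, 3, 0, 0, 0, 0], dtype=wp.int32)

    radix_sort_pairs(keys, values, count=n)
    # keys.numpy()[:n]   -> [2, 4, 7, 9]
    # values.numpy()[:n] -> [1, 3, 0, 2]  (values follow their keys)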
@@ -169,6 +206,7 @@ def segmented_sort_pairs(

     This function performs a segmented sort of key-value pairs, where the sorting is done independently within each segment.
     The segments are defined by their start and optionally end indices.
+    The `keys` and `values` arrays must be large enough to accommodate 2*`count` elements.

     Args:
         keys: Array of keys to sort. Must be of type int32 or float32.
@@ -187,7 +225,7 @@ def segmented_sort_pairs(
            if segment_start_indices is not of type int32, or if data types are unsupported.
     """
     if keys.device != values.device:
-        raise RuntimeError("Array storage devices do not match")
+        raise RuntimeError(f"Array storage devices do not match ({keys.device} vs {values.device})")

     if count == 0:
         return
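Assuming the calling convention implied by the docstring context above (keys, values, count, then per-segment start indices), a hedged sketch of segmented sorting; the exact argument order is an assumption, not confirmed by this diff:

    import warp as wp
    from warp.utils import segmented_sort_pairs

    wp.init()

    n = 6
    # Two segments, [0, 3) and [3, 6); keys/values again need 2*n storage.
    keys = wp.array([5, 1, 3, 9, 2, 7] + [0] * n, dtype=wp.int32)
    values = wp.array([0, 1, 2, 3, 4, 5] + [0] * n, dtype=wp.int32)
    segment_starts = wp.array([0, 3], dtype=wp.int32)

    segmented_sort_pairs(keys, values, n, segment_starts)
    # keys.numpy()[:n] -> [1, 3, 5, 2, 7, 9]  (each segment sorted independently)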
@@ -219,39 +257,80 @@ def segmented_sort_pairs(
     if keys.device.is_cpu:
         if keys.dtype == wp.int32 and values.dtype == wp.int32:
             runtime.core.segmented_sort_pairs_int_host(
-                keys.ptr,
+                keys.ptr,
+                values.ptr,
+                count,
+                segment_start_indices_ptr,
+                segment_end_indices_ptr,
+                num_segments,
             )
         elif keys.dtype == wp.float32 and values.dtype == wp.int32:
             runtime.core.segmented_sort_pairs_float_host(
-                keys.ptr,
+                keys.ptr,
+                values.ptr,
+                count,
+                segment_start_indices_ptr,
+                segment_end_indices_ptr,
+                num_segments,
             )
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(keys.dtype)}")
     elif keys.device.is_cuda:
         if keys.dtype == wp.int32 and values.dtype == wp.int32:
             runtime.core.segmented_sort_pairs_int_device(
-                keys.ptr,
+                keys.ptr,
+                values.ptr,
+                count,
+                segment_start_indices_ptr,
+                segment_end_indices_ptr,
+                num_segments,
             )
         elif keys.dtype == wp.float32 and values.dtype == wp.int32:
             runtime.core.segmented_sort_pairs_float_device(
-                keys.ptr,
+                keys.ptr,
+                values.ptr,
+                count,
+                segment_start_indices_ptr,
+                segment_end_indices_ptr,
+                num_segments,
             )
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(keys.dtype)}")


 def runlength_encode(values, run_values, run_lengths, run_count=None, value_count=None):
+    """Perform run-length encoding on an array.
+
+    This function compresses an array by replacing consecutive identical values with a single value
+    and its count. For example, [1,1,1,2,2,3] becomes values=[1,2,3] and lengths=[3,2,1].
+
+    Args:
+        values (wp.array): Input array to encode. Must be of type int32.
+        run_values (wp.array): Output array to store unique values. Must be at least value_count in size.
+        run_lengths (wp.array): Output array to store run lengths. Must be at least value_count in size.
+        run_count (wp.array, optional): Optional output array to store the number of runs.
+            If None, returns the count as an integer.
+        value_count (int, optional): Number of values to process. If None, processes entire array.
+
+    Returns:
+        int or wp.array: Number of runs if run_count is None, otherwise returns run_count array.
+
+    Raises:
+        RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
+    """
     if run_values.device != values.device or run_lengths.device != values.device:
-        raise RuntimeError("
+        raise RuntimeError("run_values, run_lengths and values storage devices do not match")

     if value_count is None:
         value_count = values.size

     if run_values.size < value_count or run_lengths.size < value_count:
-        raise RuntimeError("Output array storage sizes must be at least equal to value_count")
+        raise RuntimeError(f"Output array storage sizes must be at least equal to value_count ({value_count})")

-    if values.dtype != run_values.dtype:
-        raise RuntimeError(
+    if not types_equal(values.dtype, run_values.dtype):
+        raise RuntimeError(
+            f"values and run_values data types do not match ({type_repr(values.dtype)} vs {type_repr(run_values.dtype)})"
+        )

     if run_lengths.dtype != wp.int32:
         raise RuntimeError("run_lengths array must be of type int32")
@@ -270,7 +349,7 @@ def runlength_encode(values, run_values, run_lengths, run_count=None, value_count=None):
             raise RuntimeError("run_count array must be of type int32")
         if value_count == 0:
             run_count.zero_()
-            return
+            return run_count
         host_return = False

     from warp.context import runtime
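The change above makes the early-exit path return run_count instead of None, so both branches now return a value; illustrative usage (a sketch, not from the diff):

    import warp as wp
    from warp.utils import runlength_encode

    wp.init()

    values = wp.array([1, 1, 1, 2, 2, 3], dtype=wp.int32)
    run_values = wp.empty(values.size, dtype=wp.int32)
    run_lengths = wp.empty(values.size, dtype=wp.int32)

    # With run_count=None the number of runs is returned as a Python int.
    num_runs = runlength_encode(values, run_values, run_lengths)
    # num_runs                -> 3
    # run_values.numpy()[:3]  -> [1, 2, 3]
    # run_lengths.numpy()[:3] -> [3, 2, 1]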
@@ -281,20 +360,39 @@ def runlength_encode(values, run_values, run_lengths, run_count=None, value_count=None):
                 values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
             )
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")
     elif values.device.is_cuda:
         if values.dtype == wp.int32:
             runtime.core.runlength_encode_int_device(
                 values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
             )
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")

     if host_return:
         return int(run_count.numpy()[0])
+    return run_count


 def array_sum(values, out=None, value_count=None, axis=None):
+    """Compute the sum of array elements.
+
+    This function computes the sum of array elements, optionally along a specified axis.
+    The operation can be performed on the entire array or along a specific dimension.
+
+    Args:
+        values (wp.array): Input array to sum. Must be of type float32 or float64.
+        out (wp.array, optional): Output array to store results. If None, a new array is created.
+        value_count (int, optional): Number of elements to process. If None, processes entire array.
+        axis (int, optional): Axis along which to compute sum. If None, computes sum of all elements.
+
+    Returns:
+        wp.array or float: The sum result. Returns a float if axis is None and out is None,
+            otherwise returns the output array.
+
+    Raises:
+        RuntimeError: If output array storage device or data type is incompatible with input array.
+    """
     if value_count is None:
         if axis is None:
             value_count = values.size
@@ -310,7 +408,7 @@ def array_sum(values, out=None, value_count=None, axis=None):

     output_shape = tuple(output_dim(ax, dim) for ax, dim in enumerate(values.shape))

-    type_length = wp.types.type_length(values.dtype)
+    type_size = wp.types.type_size(values.dtype)
     scalar_type = wp.types.type_scalar_type(values.dtype)

     # User can provide a device output array for storing the number of runs
@@ -341,48 +439,67 @@ def array_sum(values, out=None, value_count=None, axis=None):
         elif scalar_type == wp.float64:
             native_func = runtime.core.array_sum_double_host
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")
     elif values.device.is_cuda:
         if scalar_type == wp.float32:
             native_func = runtime.core.array_sum_float_device
         elif scalar_type == wp.float64:
             native_func = runtime.core.array_sum_double_device
         else:
-            raise RuntimeError("Unsupported data type")
+            raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")

     if axis is None:
         stride = wp.types.type_size_in_bytes(values.dtype)
-        native_func(values.ptr, out.ptr, value_count, stride, type_length)
+        native_func(values.ptr, out.ptr, value_count, stride, type_size)

         if host_return:
             return out.numpy()[0]
-
-    stride = values.strides[axis]
-    for idx in np.ndindex(output_shape):
-        out_offset = sum(i * s for i, s in zip(idx, out.strides))
-        val_offset = sum(i * s for i, s in zip(idx, values.strides))
-
-        native_func(
-            values.ptr + val_offset,
-            out.ptr + out_offset,
-            value_count,
-            stride,
-            type_length,
-        )
+        return out

-
-
+    stride = values.strides[axis]
+    for idx in np.ndindex(output_shape):
+        out_offset = sum(i * s for i, s in zip(idx, out.strides))
+        val_offset = sum(i * s for i, s in zip(idx, values.strides))
+
+        native_func(
+            values.ptr + val_offset,
+            out.ptr + out_offset,
+            value_count,
+            stride,
+            type_size,
+        )
+
+    return out


 def array_inner(a, b, out=None, count=None, axis=None):
+    """Compute the inner product of two arrays.
+
+    This function computes the dot product between two arrays, optionally along a specified axis.
+    The operation can be performed on the entire arrays or along a specific dimension.
+
+    Args:
+        a (wp.array): First input array.
+        b (wp.array): Second input array. Must match shape and type of a.
+        out (wp.array, optional): Output array to store results. If None, a new array is created.
+        count (int, optional): Number of elements to process. If None, processes entire arrays.
+        axis (int, optional): Axis along which to compute inner product. If None, computes on flattened arrays.
+
+    Returns:
+        wp.array or float: The inner product result. Returns a float if axis is None and out is None,
+            otherwise returns the output array.
+
+    Raises:
+        RuntimeError: If array storage devices, sizes, or data types are incompatible.
+    """
     if a.size != b.size:
-        raise RuntimeError("
+        raise RuntimeError(f"A and b array storage sizes do not match ({a.size} vs {b.size})")

     if a.device != b.device:
-        raise RuntimeError("
+        raise RuntimeError(f"A and b array storage devices do not match ({a.device} vs {b.device})")

-    if a.dtype != b.dtype:
-        raise RuntimeError("
+    if not types_equal(a.dtype, b.dtype):
+        raise RuntimeError(f"A and b array data types do not match ({type_repr(a.dtype)} vs {type_repr(b.dtype)})")

     if count is None:
         if axis is None:
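With the axis path of array_sum now returning out explicitly, usage looks like this (a minimal sketch; the axis semantics shown are inferred from the docstring):

    import warp as wp
    from warp.utils import array_sum

    wp.init()

    values = wp.array([[1.0, 2.0], [3.0, 4.0]], dtype=wp.float32)

    total = array_sum(values)            # scalar float: 10.0
    per_col = array_sum(values, axis=0)  # wp.array: [4.0, 6.0]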
|
399
516
|
|
|
400
517
|
output_shape = tuple(output_dim(ax, dim) for ax, dim in enumerate(a.shape))
|
|
401
518
|
|
|
402
|
-
|
|
519
|
+
type_size = wp.types.type_size(a.dtype)
|
|
403
520
|
scalar_type = wp.types.type_scalar_type(a.dtype)
|
|
404
521
|
|
|
405
522
|
# User can provide a device output array for storing the number of runs
|
|
@@ -430,43 +547,43 @@ def array_inner(a, b, out=None, count=None, axis=None):
|
|
|
430
547
|
elif scalar_type == wp.float64:
|
|
431
548
|
native_func = runtime.core.array_inner_double_host
|
|
432
549
|
else:
|
|
433
|
-
raise RuntimeError("Unsupported data type")
|
|
550
|
+
raise RuntimeError(f"Unsupported data type: {type_repr(a.dtype)}")
|
|
434
551
|
elif a.device.is_cuda:
|
|
435
552
|
if scalar_type == wp.float32:
|
|
436
553
|
native_func = runtime.core.array_inner_float_device
|
|
437
554
|
elif scalar_type == wp.float64:
|
|
438
555
|
native_func = runtime.core.array_inner_double_device
|
|
439
556
|
else:
|
|
440
|
-
raise RuntimeError("Unsupported data type")
|
|
557
|
+
raise RuntimeError(f"Unsupported data type: {type_repr(a.dtype)}")
|
|
441
558
|
|
|
442
559
|
if axis is None:
|
|
443
560
|
stride_a = wp.types.type_size_in_bytes(a.dtype)
|
|
444
561
|
stride_b = wp.types.type_size_in_bytes(b.dtype)
|
|
445
|
-
native_func(a.ptr, b.ptr, out.ptr, count, stride_a, stride_b,
|
|
562
|
+
native_func(a.ptr, b.ptr, out.ptr, count, stride_a, stride_b, type_size)
|
|
446
563
|
|
|
447
564
|
if host_return:
|
|
448
565
|
return out.numpy()[0]
|
|
449
|
-
|
|
450
|
-
stride_a = a.strides[axis]
|
|
451
|
-
stride_b = b.strides[axis]
|
|
452
|
-
|
|
453
|
-
for idx in np.ndindex(output_shape):
|
|
454
|
-
out_offset = sum(i * s for i, s in zip(idx, out.strides))
|
|
455
|
-
a_offset = sum(i * s for i, s in zip(idx, a.strides))
|
|
456
|
-
b_offset = sum(i * s for i, s in zip(idx, b.strides))
|
|
457
|
-
|
|
458
|
-
native_func(
|
|
459
|
-
a.ptr + a_offset,
|
|
460
|
-
b.ptr + b_offset,
|
|
461
|
-
out.ptr + out_offset,
|
|
462
|
-
count,
|
|
463
|
-
stride_a,
|
|
464
|
-
stride_b,
|
|
465
|
-
type_length,
|
|
466
|
-
)
|
|
566
|
+
return out
|
|
467
567
|
|
|
468
|
-
|
|
469
|
-
|
|
568
|
+
stride_a = a.strides[axis]
|
|
569
|
+
stride_b = b.strides[axis]
|
|
570
|
+
|
|
571
|
+
for idx in np.ndindex(output_shape):
|
|
572
|
+
out_offset = sum(i * s for i, s in zip(idx, out.strides))
|
|
573
|
+
a_offset = sum(i * s for i, s in zip(idx, a.strides))
|
|
574
|
+
b_offset = sum(i * s for i, s in zip(idx, b.strides))
|
|
575
|
+
|
|
576
|
+
native_func(
|
|
577
|
+
a.ptr + a_offset,
|
|
578
|
+
b.ptr + b_offset,
|
|
579
|
+
out.ptr + out_offset,
|
|
580
|
+
count,
|
|
581
|
+
stride_a,
|
|
582
|
+
stride_b,
|
|
583
|
+
type_size,
|
|
584
|
+
)
|
|
585
|
+
|
|
586
|
+
return out
|
|
470
587
|
|
|
471
588
|
|
|
472
589
|
@wp.kernel
|
|
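Corresponding usage of array_inner (a sketch, not part of the diff):

    import warp as wp
    from warp.utils import array_inner

    wp.init()

    a = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)
    b = wp.array([4.0, 5.0, 6.0], dtype=wp.float32)

    dot = array_inner(a, b)  # 1*4 + 2*5 + 3*6 = 32.0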
@@ -479,8 +596,28 @@ def _array_cast_kernel(


 def array_cast(in_array, out_array, count=None):
+    """Cast elements from one array to another array with a different data type.
+
+    This function performs element-wise casting from the input array to the output array.
+    The arrays must have the same number of dimensions and data type shapes. If they don't match,
+    the arrays will be flattened and casting will be performed at the scalar level.
+
+    Args:
+        in_array (wp.array): Input array to cast from.
+        out_array (wp.array): Output array to cast to. Must have the same device as in_array.
+        count (int, optional): Number of elements to process. If None, processes entire array.
+            For multi-dimensional arrays, partial casting is not supported.
+
+    Raises:
+        RuntimeError: If arrays have different devices or if attempting partial casting
+            on multi-dimensional arrays.
+
+    Note:
+        If the input and output arrays have the same data type, this function will
+        simply copy the data without any conversion.
+    """
     if in_array.device != out_array.device:
-        raise RuntimeError("Array storage devices do not match")
+        raise RuntimeError(f"Array storage devices do not match ({in_array.device} vs {out_array.device})")

     in_array_data_shape = getattr(in_array.dtype, "_shape_", ())
     out_array_data_shape = getattr(out_array.dtype, "_shape_", ())
@@ -491,8 +628,8 @@ def array_cast(in_array, out_array, count=None):
         in_array = in_array.flatten()
         out_array = out_array.flatten()

-        in_array_data_length = warp.types.type_length(in_array.dtype)
-        out_array_data_length = warp.types.type_length(out_array.dtype)
+        in_array_data_length = warp.types.type_size(in_array.dtype)
+        out_array_data_length = warp.types.type_size(out_array.dtype)
         in_array_scalar_type = wp.types.type_scalar_type(in_array.dtype)
         out_array_scalar_type = wp.types.type_scalar_type(out_array.dtype)

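A short sketch of the documented casting behavior (not part of the diff; the float-to-int result assumes Warp's truncating scalar conversion):

    import warp as wp
    from warp.utils import array_cast

    wp.init()

    src = wp.array([1.5, 2.5, 3.5], dtype=wp.float32)
    dst = wp.empty(3, dtype=wp.int32)

    array_cast(src, dst)  # dst -> [1, 2, 3]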
@@ -534,6 +671,430 @@ def array_cast(in_array, out_array, count=None):
         wp.launch(kernel=_array_cast_kernel, dim=dim, inputs=[out_array, in_array], device=out_array.device)


+def create_warp_function(func: Callable) -> tuple[wp.Function, warp.context.Module]:
+    """Create a Warp function from a Python function.
+
+    Args:
+        func (Callable): A Python function to be converted to a Warp function.
+
+    Returns:
+        wp.Function: A Warp function created from the input function.
+    """
+
+    from .codegen import Adjoint, get_full_arg_spec
+
+    def unique_name(code: str):
+        return "func_" + hex(hash(code))[-8:]
+
+    # Create a Warp function from the input function
+    source = None
+    argspec = get_full_arg_spec(func)
+    key = getattr(func, "__name__", None)
+    if key is None:
+        source, _ = Adjoint.extract_function_source(func)
+        key = unique_name(source)
+    elif key == "<lambda>":
+        body = Adjoint.extract_lambda_source(func, only_body=True)
+        if body is None:
+            raise ValueError("Could not extract lambda source code")
+        key = unique_name(body)
+        source = f"def {key}({', '.join(argspec.args)}):\n    return {body}"
+    else:
+        # use the qualname of the function as the key
+        key = getattr(func, "__qualname__", key)
+        key = key.replace(".", "_").replace(" ", "_").replace("<", "").replace(">", "_")
+
+    module = warp.context.get_module(f"map_{key}")
+    func = wp.Function(
+        func,
+        namespace="",
+        module=module,
+        key=key,
+        source=source,
+        overloaded_annotations=dict.fromkeys(argspec.args, Any),
+    )
+    return func, module
+
+
+def broadcast_shapes(shapes: list[tuple[int]]) -> tuple[int]:
+    """Broadcast a list of shapes to a common shape.
+
+    Following the broadcasting rules of NumPy, two shapes are compatible when:
+    starting from the trailing dimension,
+    1. the two dimensions are equal, or
+    2. one of the dimensions is 1.
+
+    Example:
+        >>> broadcast_shapes([(3, 1, 4), (5, 4)])
+        (3, 5, 4)
+
+    Returns:
+        tuple[int]: The broadcasted shape.
+
+    Raises:
+        ValueError: If the shapes are not broadcastable.
+    """
+    ref = shapes[0]
+    for shape in shapes[1:]:
+        broad = []
+        for j in range(1, max(len(ref), len(shape)) + 1):
+            if j <= len(ref) and j <= len(shape):
+                s = shape[-j]
+                r = ref[-j]
+                if s == r:
+                    broad.append(s)
+                elif s == 1 or r == 1:
+                    broad.append(max(s, r))
+                else:
+                    raise ValueError(f"Shapes {ref} and {shape} are not broadcastable")
+            elif j <= len(ref):
+                broad.append(ref[-j])
+            else:
+                broad.append(shape[-j])
+        ref = tuple(reversed(broad))
+    return ref
+
+
+def map(
+    func: Callable | wp.Function,
+    *inputs: Array[DType] | Any,
+    out: Array[DType] | list[Array[DType]] | None = None,
+    return_kernel: bool = False,
+    block_dim=256,
+    device: Devicelike = None,
+) -> Array[DType] | list[Array[DType]] | wp.Kernel:
+    """
+    Map a function over the elements of one or more arrays.
+
+    You can use a Warp function, a regular Python function, or a lambda expression to map it to a set of arrays.
+
+    .. testcode::
+
+        a = wp.array([1, 2, 3], dtype=wp.float32)
+        b = wp.array([4, 5, 6], dtype=wp.float32)
+        c = wp.array([7, 8, 9], dtype=wp.float32)
+        result = wp.map(lambda x, y, z: x + 2.0 * y - z, a, b, c)
+        print(result)
+
+    .. testoutput::
+
+        [2. 4. 6.]
+
+    Clamp values in an array in place:
+
+    .. testcode::
+
+        xs = wp.array([-1.0, 0.0, 1.0], dtype=wp.float32)
+        wp.map(wp.clamp, xs, -0.5, 0.5, out=xs)
+        print(xs)
+
+    .. testoutput::
+
+        [-0.5  0.   0.5]
+
+    Note that only one of the inputs needs to be a Warp array. For example, it is possible
+    to vectorize the function :func:`warp.transform_point` over a collection of points
+    with a given input transform as follows:
+
+    .. code-block:: python
+
+        tf = wp.transform((1.0, 2.0, 3.0), wp.quat_rpy(0.2, -0.6, 0.1))
+        points = wp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=wp.vec3)
+        transformed = wp.map(wp.transform_point, tf, points)
+
+    Besides regular Warp arrays, other array types, such as the ``indexedarray``, are supported as well:
+
+    .. testcode::
+
+        arr = wp.array(data=np.arange(10, dtype=np.float32))
+        indices = wp.array([1, 3, 5, 7, 9], dtype=int)
+        iarr = wp.indexedarray1d(arr, [indices])
+        out = wp.map(lambda x: x * 10.0, iarr)
+        print(out)
+
+    .. testoutput::
+
+        [10. 30. 50. 70. 90.]
+
+    If multiple arrays are provided, the
+    `NumPy broadcasting rules <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
+    are applied to determine the shape of the output array.
+    Two shapes are compatible when: starting from the trailing dimension,
+
+    1. the two dimensions are equal, or
+    2. one of the dimensions is 1.
+
+    For example, given arrays of shapes ``(3, 1, 4)`` and ``(5, 4)``, the broadcasted
+    shape is ``(3, 5, 4)``.
+
+    If no array(s) are provided to the ``out`` argument, the output array(s) are created automatically.
+    The data type(s) of the output array(s) are determined by the type of the return value(s) of
+    the function. The ``requires_grad`` flag for an automatically created output array is set to ``True``
+    if any of the input arrays have it set to ``True`` and the respective output array's ``dtype`` is a type that
+    supports differentiation.
+
+    Args:
+        func (Callable | Function): The function to map over the arrays.
+        *inputs (array | Any): The input arrays or values to pass to the function.
+        out (array | list[array] | None): Optional output array(s) to store the result(s). If None, the output array(s) will be created automatically.
+        return_kernel (bool): If True, only return the generated kernel without performing the mapping operation.
+        block_dim (int): The block dimension for the kernel launch.
+        device (Devicelike): The device on which to run the kernel.
+
+    Returns:
+        array | list[array] | Kernel:
+            The resulting array(s) of the mapping. If ``return_kernel`` is True, only returns the kernel used for mapping.
+    """
+
+    import builtins
+
+    from .codegen import Adjoint, Struct, StructInstance
+    from .types import (
+        is_array,
+        type_is_matrix,
+        type_is_quaternion,
+        type_is_transformation,
+        type_is_vector,
+        type_repr,
+        type_to_warp,
+        types_equal,
+    )
+
+    # mapping from struct name to its Python definition
+    referenced_modules: dict[str, ModuleType] = {}
+
+    def type_to_code(wp_type) -> str:
+        """Returns the string representation of a given Warp type."""
+        if is_array(wp_type):
+            return f"warp.array(ndim={wp_type.ndim}, dtype={type_to_code(wp_type.dtype)})"
+        if isinstance(wp_type, Struct):
+            key = f"{wp_type.__module__}.{wp_type.key}"
+            module = sys.modules.get(wp_type.__module__, None)
+            if module is not None:
+                referenced_modules[wp_type.__module__] = module
+            return key
+        if type_is_transformation(wp_type):
+            return f"warp.types.transformation(dtype={type_to_code(wp_type._wp_scalar_type_)})"
+        if type_is_quaternion(wp_type):
+            return f"warp.types.quaternion(dtype={type_to_code(wp_type._wp_scalar_type_)})"
+        if type_is_vector(wp_type):
+            return f"warp.types.vector(length={wp_type._shape_[0]}, dtype={type_to_code(wp_type._wp_scalar_type_)})"
+        if type_is_matrix(wp_type):
+            return f"warp.types.matrix(shape=({wp_type._shape_[0]}, {wp_type._shape_[1]}), dtype={type_to_code(wp_type._wp_scalar_type_)})"
+        if wp_type == builtins.bool:
+            return "bool"
+        if wp_type == builtins.float:
+            return "float"
+        if wp_type == builtins.int:
+            return "int"
+
+        name = getattr(wp_type, "__name__", None)
+        if name is None:
+            return type_repr(wp_type)
+        name = getattr(wp_type, "__qualname__", name)
+        module = getattr(wp_type, "__module__", None)
+        if module is not None:
+            referenced_modules[wp_type.__module__] = module
+        return wp_type.__module__ + "." + name
+
+    def get_warp_type(value):
+        dtype = type(value)
+        if issubclass(dtype, StructInstance):
+            # a struct
+            return value._cls
+        return type_to_warp(dtype)
+
+    # gather the arrays in the inputs
+    array_shapes = [a.shape for a in inputs if is_array(a)]
+    if len(array_shapes) == 0:
+        raise ValueError("map requires at least one warp.array input")
+    # broadcast the shapes of the arrays
+    out_shape = broadcast_shapes(array_shapes)
+
+    module = None
+    out_dtypes = None
+    skip_arg_type_checks = False
+    if isinstance(func, wp.Function):
+        func_name = func.key
+        wp_func = func
+    else:
+        # check if op is a callable function
+        if not callable(func):
+            raise TypeError("func must be a callable function or a warp.Function")
+        wp_func, module = create_warp_function(func)
+        func_name = wp_func.key
+        # we created a generic function here (arg types are all Any)
+        skip_arg_type_checks = True
+    if module is None:
+        module = warp.context.get_module(f"map_{func_name}")
+
+    arg_names = list(wp_func.input_types.keys())
+    # determine output dtype
+    if wp_func.value_func is not None or wp_func.value_type is not None:
+        arg_types = {}
+        arg_values = {}
+        for i, arg_name in enumerate(arg_names):
+            if is_array(inputs[i]):
+                # we will pass an element of the array to the function
+                arg_types[arg_name] = inputs[i].dtype
+                if device is None:
+                    device = inputs[i].device
+            else:
+                # we pass the input value directly to the function
+                arg_types[arg_name] = get_warp_type(inputs[i])
+        func_or_none = wp_func.get_overload(list(arg_types.values()), {})
+        if func_or_none is None:
+            raise TypeError(
+                f"Function {func_name} does not support the provided argument types {', '.join(type_repr(t) for t in arg_types.values())}"
+            )
+        func = func_or_none
+        if func.value_func is not None:
+            out_dtype = func.value_func(arg_types, arg_values)
+        else:
+            out_dtype = func.value_type
+        if isinstance(out_dtype, tuple) or isinstance(out_dtype, list):
+            out_dtypes = out_dtype
+        else:
+            out_dtypes = (out_dtype,)
+    else:
+        # try to evaluate the function to determine the output type
+        args = []
+        arg_types = wp_func.input_types
+        if len(inputs) != len(arg_types):
+            raise TypeError(
+                f"Number of input arguments ({len(inputs)}) does not match expected number of function arguments ({len(arg_types)})"
+            )
+        for (arg_name, arg_type), input in zip(arg_types.items(), inputs):
+            if is_array(input):
+                if not skip_arg_type_checks and not types_equal(input.dtype, arg_type):
+                    raise TypeError(
+                        f'Incorrect input provided for argument "{arg_name}": received array of dtype {type_repr(input.dtype)}, expected {type_repr(arg_type)}'
+                    )
+                args.append(input.dtype())
+                if device is None:
+                    device = input.device
+            else:
+                if not skip_arg_type_checks and not types_equal(type(input), arg_type):
+                    raise TypeError(
+                        f'Incorrect input provided for argument "{arg_name}": received {type_repr(type(input))}, expected {type_repr(arg_type)}'
+                    )
+                args.append(input)
+        result = wp_func(*args)
+        if result is None:
+            raise TypeError("The provided function must return a value")
+        if isinstance(result, tuple) or isinstance(result, list):
+            out_dtypes = tuple(get_warp_type(r) for r in result)
+        else:
+            out_dtypes = (get_warp_type(result),)
+
+    if out_dtypes is None:
+        raise TypeError("Could not determine the output type of the function, make sure it returns a value")
+
+    if out is None:
+        requires_grad = any(getattr(a, "requires_grad", False) for a in inputs if is_array(a))
+        outputs = []
+        for dtype in out_dtypes:
+            rg = requires_grad and Adjoint.is_differentiable_value_type(dtype)
+            outputs.append(wp.empty(out_shape, dtype=dtype, requires_grad=rg, device=device))
+    elif len(out_dtypes) == 1 and is_array(out):
+        if not types_equal(out.dtype, out_dtypes[0]):
+            raise TypeError(
+                f"Output array dtype {type_repr(out.dtype)} does not match expected dtype {type_repr(out_dtypes[0])}"
+            )
+        if out.shape != out_shape:
+            raise TypeError(f"Output array shape {out.shape} does not match expected shape {out_shape}")
+        outputs = [out]
+    elif len(out_dtypes) > 1:
+        if isinstance(out, tuple) or isinstance(out, list):
+            if len(out) != len(out_dtypes):
+                raise TypeError(
+                    f"Number of provided output arrays ({len(out)}) does not match expected number of function outputs ({len(out_dtypes)})"
+                )
+            for i, a in enumerate(out):
+                if not types_equal(a.dtype, out_dtypes[i]):
+                    raise TypeError(
+                        f"Output array {i} dtype {type_repr(a.dtype)} does not match expected dtype {type_repr(out_dtypes[i])}"
+                    )
+                if a.shape != out_shape:
+                    raise TypeError(f"Output array {i} shape {a.shape} does not match expected shape {out_shape}")
+            outputs = list(out)
+        else:
+            raise TypeError(
+                f"Invalid output provided, expected {len(out_dtypes)} Warp arrays with shape {out_shape} and dtypes ({', '.join(type_repr(t) for t in out_dtypes)})"
+            )
+
+    # create code for a kernel
+    code = """def map_kernel({kernel_args}):
+    {tids} = wp.tid()
+    {load_args}
+    """
+    if len(outputs) == 1:
+        code += "__out_0[{tids}] = {func_name}({arg_names})"
+    else:
+        code += ", ".join(f"__o_{i}" for i in range(len(outputs)))
+        code += " = {func_name}({arg_names})\n"
+        for i in range(len(outputs)):
+            code += f"    __out_{i}" + "[{tids}]" + f" = __o_{i}\n"
+
+    tids = [f"__tid_{i}" for i in range(len(out_shape))]
+
+    load_args = []
+    kernel_args = []
+    for arg_name, input in zip(arg_names, inputs):
+        if is_array(input):
+            arr_name = f"{arg_name}_array"
+            array_type_name = type(input).__name__
+            kernel_args.append(
+                f"{arr_name}: wp.{array_type_name}(dtype={type_to_code(input.dtype)}, ndim={input.ndim})"
+            )
+            shape = input.shape
+            indices = []
+            for i in range(1, len(shape) + 1):
+                if shape[-i] == 1:
+                    indices.append("0")
+                else:
+                    indices.append(tids[-i])
+
+            load_args.append(f"{arg_name} = {arr_name}[{', '.join(reversed(indices))}]")
+        else:
+            kernel_args.append(f"{arg_name}: {type_to_code(type(input))}")
+    for i, o in enumerate(outputs):
+        array_type_name = type(o).__name__
+        kernel_args.append(f"__out_{i}: wp.{array_type_name}(dtype={type_to_code(o.dtype)}, ndim={o.ndim})")
+    code = code.format(
+        func_name=func_name,
+        kernel_args=", ".join(kernel_args),
+        arg_names=", ".join(arg_names),
+        tids=", ".join(tids),
+        load_args="\n    ".join(load_args),
+    )
+    namespace = {}
+    namespace.update({"wp": wp, "warp": wp, func_name: wp_func, "Any": Any})
+    namespace.update(referenced_modules)
+    exec(code, namespace)
+
+    kernel = wp.Kernel(namespace["map_kernel"], key="map_kernel", source=code, module=module)
+    if return_kernel:
+        return kernel
+
+    wp.launch(
+        kernel,
+        dim=out_shape,
+        inputs=inputs,
+        outputs=outputs,
+        block_dim=block_dim,
+        device=device,
+    )
+
+    if len(outputs) == 1:
+        o = outputs[0]
+    else:
+        o = outputs
+
+    return o
+
+
 # code snippet for invoking cProfile
 # cp = cProfile.Profile()
 # cp.enable()
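The broadcasting rules described in the new map docstring combine with broadcast_shapes like this (a hedged sketch against the 1.8.0 wp.map API added above):

    import warp as wp

    wp.init()

    # Shapes (3, 1) and (2,) broadcast to (3, 2), mirroring NumPy's rules.
    a = wp.array([[1.0], [2.0], [3.0]], dtype=wp.float32)  # shape (3, 1)
    b = wp.array([10.0, 20.0], dtype=wp.float32)           # shape (2,)

    result = wp.map(lambda x, y: x + y, a, b)
    # result.shape -> (3, 2); row i is a[i, 0] + b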
@@ -634,7 +1195,7 @@ def mem_report():  # pragma: no cover
             element_size = tensor.storage().element_size()
             mem = numel * element_size / 1024 / 1024  # 32bit=4Byte, MByte
             total_mem += mem
-        print("Type:
+        print(f"Type: {mem_type:<4} | Total Tensors: {total_numel:>8} | Used Memory: {total_mem:>8.2f} MB")

     import gc

@@ -712,7 +1273,7 @@ class ScopedStream:
         device (Device): The device associated with the stream.
     """

-    def __init__(self, stream:
+    def __init__(self, stream: wp.Stream | None, sync_enter: bool = True, sync_exit: bool = False):
         """Initializes the context manager with a stream and synchronization options.

         Args:
@@ -765,12 +1326,12 @@ class ScopedTimer:
         active: bool = True,
         print: bool = True,
         detailed: bool = False,
-        dict:
+        dict: dict[str, list[float]] | None = None,
         use_nvtx: bool = False,
-        color:
+        color: int | str = "rapids",
         synchronize: bool = False,
         cuda_filter: int = 0,
-        report_func:
+        report_func: Callable[[list[TimingResult], str], None] | None = None,
         skip_tape: bool = False,
     ):
         """Context manager object for a timer
@@ -792,7 +1353,7 @@ class ScopedTimer:
         Attributes:
             extra_msg (str): Can be set to a string that will be added to the printout at context exit.
            elapsed (float): The duration of the ``with`` block used with this object
-            timing_results (
+            timing_results (list[TimingResult]): The list of activity timing results, if collection was requested using ``cuda_filter``
        """
        self.name = name
        self.active = active and self.enabled
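The tightened annotations document the existing dict accumulation behavior; for example (a sketch, assuming elapsed times are reported in milliseconds as in prior releases):

    import warp as wp

    wp.init()

    timings: dict[str, list[float]] = {}

    with wp.ScopedTimer("fill", dict=timings):
        a = wp.zeros(1_000_000, dtype=wp.float32)
        a.fill_(1.0)

    print(timings["fill"])  # one elapsed-time entry per completed 'with' block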
@@ -986,12 +1547,12 @@ def check_p2p():
 class timing_result_t(ctypes.Structure):
     """CUDA timing struct for fetching values from C++"""

-    _fields_ = [
+    _fields_ = (
         ("context", ctypes.c_void_p),
         ("name", ctypes.c_char_p),
         ("filter", ctypes.c_int),
         ("elapsed", ctypes.c_float),
-    ]
+    )


 class TimingResult:
@@ -1025,7 +1586,7 @@ def timing_begin(cuda_filter: int = TIMING_ALL, synchronize: bool = True) -> None:
     warp.context.runtime.core.cuda_timing_begin(cuda_filter)


-def timing_end(synchronize: bool = True) ->
+def timing_end(synchronize: bool = True) -> list[TimingResult]:
     """End detailed activity timing.

     Parameters:
@@ -1071,7 +1632,7 @@ def timing_end(synchronize: bool = True) -> List[TimingResult]:
     return results


-def timing_print(results:
+def timing_print(results: list[TimingResult], indent: str = "") -> None:
     """Print timing results.

     Parameters:
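The annotated timing API is typically used as below (a sketch, not from the diff; detailed activity timing requires a CUDA-capable device):

    import warp as wp

    wp.init()

    wp.timing_begin(cuda_filter=wp.TIMING_ALL)

    a = wp.zeros(1_000_000, dtype=wp.float32, device="cuda:0")
    b = wp.clone(a)  # recorded as a memcpy activity

    results = wp.timing_end()  # list[TimingResult]
    wp.timing_print(results)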
|