warp-lang 1.0.0b2-py3-none-win_amd64.whl → 1.0.0b6-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/conf.py +17 -5
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/env/env_usd.py +4 -1
- examples/env/environment.py +8 -9
- examples/example_dem.py +34 -33
- examples/example_diffray.py +364 -337
- examples/example_fluid.py +32 -23
- examples/example_jacobian_ik.py +97 -93
- examples/example_marching_cubes.py +6 -16
- examples/example_mesh.py +6 -16
- examples/example_mesh_intersect.py +16 -14
- examples/example_nvdb.py +14 -16
- examples/example_raycast.py +14 -13
- examples/example_raymarch.py +16 -23
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +82 -78
- examples/example_sim_cloth.py +45 -48
- examples/example_sim_fk_grad.py +51 -44
- examples/example_sim_fk_grad_torch.py +47 -40
- examples/example_sim_grad_bounce.py +108 -133
- examples/example_sim_grad_cloth.py +99 -113
- examples/example_sim_granular.py +5 -6
- examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
- examples/example_sim_neo_hookean.py +51 -55
- examples/example_sim_particle_chain.py +4 -4
- examples/example_sim_quadruped.py +126 -81
- examples/example_sim_rigid_chain.py +54 -61
- examples/example_sim_rigid_contact.py +66 -70
- examples/example_sim_rigid_fem.py +3 -3
- examples/example_sim_rigid_force.py +1 -1
- examples/example_sim_rigid_gyroscopic.py +3 -4
- examples/example_sim_rigid_kinematics.py +28 -39
- examples/example_sim_trajopt.py +112 -110
- examples/example_sph.py +9 -8
- examples/example_wave.py +7 -7
- examples/fem/bsr_utils.py +30 -17
- examples/fem/example_apic_fluid.py +85 -69
- examples/fem/example_convection_diffusion.py +97 -93
- examples/fem/example_convection_diffusion_dg.py +142 -149
- examples/fem/example_convection_diffusion_dg0.py +141 -136
- examples/fem/example_deformed_geometry.py +146 -0
- examples/fem/example_diffusion.py +115 -84
- examples/fem/example_diffusion_3d.py +116 -86
- examples/fem/example_diffusion_mgpu.py +102 -79
- examples/fem/example_mixed_elasticity.py +139 -100
- examples/fem/example_navier_stokes.py +175 -162
- examples/fem/example_stokes.py +143 -111
- examples/fem/example_stokes_transfer.py +186 -157
- examples/fem/mesh_utils.py +59 -97
- examples/fem/plot_utils.py +138 -17
- tools/ci/publishing/build_nodes_info.py +54 -0
- warp/__init__.py +4 -3
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +836 -492
- warp/codegen.py +864 -553
- warp/config.py +3 -1
- warp/context.py +389 -172
- warp/fem/__init__.py +24 -6
- warp/fem/cache.py +318 -25
- warp/fem/dirichlet.py +7 -3
- warp/fem/domain.py +14 -0
- warp/fem/field/__init__.py +30 -38
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +244 -138
- warp/fem/field/restriction.py +8 -6
- warp/fem/field/test.py +127 -59
- warp/fem/field/trial.py +117 -60
- warp/fem/geometry/__init__.py +5 -1
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +24 -1
- warp/fem/geometry/geometry.py +86 -14
- warp/fem/geometry/grid_2d.py +112 -54
- warp/fem/geometry/grid_3d.py +134 -65
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +85 -33
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +451 -115
- warp/fem/geometry/trimesh_2d.py +197 -92
- warp/fem/integrate.py +534 -268
- warp/fem/operator.py +58 -31
- warp/fem/polynomial.py +11 -0
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +150 -58
- warp/fem/quadrature/quadrature.py +209 -57
- warp/fem/space/__init__.py +230 -53
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +49 -2
- warp/fem/space/function_space.py +90 -39
- warp/fem/space/grid_2d_function_space.py +149 -496
- warp/fem/space/grid_3d_function_space.py +173 -538
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +129 -76
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +46 -34
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +132 -1039
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +104 -742
- warp/fem/types.py +13 -11
- warp/fem/utils.py +335 -60
- warp/native/array.h +120 -34
- warp/native/builtin.h +101 -72
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +22 -40
- warp/native/clang/clang.cpp +1 -0
- warp/native/crt.h +2 -0
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1522 -1243
- warp/native/intersect.h +19 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +76 -17
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -18
- warp/native/mesh.h +395 -40
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +44 -34
- warp/native/reduce.cpp +1 -1
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +163 -155
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +18 -14
- warp/native/vec.h +103 -21
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +28 -3
- warp/native/warp.h +4 -3
- warp/render/render_opengl.py +261 -109
- warp/sim/__init__.py +1 -2
- warp/sim/articulation.py +385 -185
- warp/sim/import_mjcf.py +59 -48
- warp/sim/import_urdf.py +15 -15
- warp/sim/import_usd.py +174 -102
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_xpbd.py +4 -3
- warp/sim/model.py +330 -250
- warp/sim/render.py +1 -1
- warp/sparse.py +625 -152
- warp/stubs.py +341 -309
- warp/tape.py +9 -6
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +94 -74
- warp/tests/test_array.py +82 -101
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +22 -12
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +18 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +165 -134
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +237 -0
- warp/tests/test_fabricarray.py +22 -24
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1034 -124
- warp/tests/test_fp16.py +23 -16
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +123 -181
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +35 -34
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +24 -25
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +304 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +60 -22
- warp/tests/test_mesh_query_aabb.py +21 -25
- warp/tests/test_mesh_query_point.py +111 -22
- warp/tests/test_mesh_query_ray.py +12 -24
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +90 -86
- warp/tests/test_transient_module.py +10 -12
- warp/tests/test_types.py +363 -0
- warp/tests/test_utils.py +451 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +418 -376
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/unittest_utils.py +342 -0
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthrough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +589 -0
- warp/types.py +622 -211
- warp/utils.py +54 -393
- warp_lang-1.0.0b6.dist-info/METADATA +238 -0
- warp_lang-1.0.0b6.dist-info/RECORD +409 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- examples/example_cache_management.py +0 -40
- examples/example_multigpu.py +0 -54
- examples/example_struct.py +0 -65
- examples/fem/example_stokes_transfer_3d.py +0 -210
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/fem/field/discrete_field.py +0 -80
- warp/fem/space/nodal_function_space.py +0 -233
- warp/tests/test_all.py +0 -223
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-1.0.0b2.dist-info/METADATA +0 -26
- warp_lang-1.0.0b2.dist-info/RECORD +0 -380
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
@@ -7,8 +7,10 @@
 
 import ast
 import ctypes
+import gc
 import hashlib
 import inspect
+import io
 import os
 import platform
 import sys
@@ -68,6 +70,8 @@ class Function:
         native_func=None,
         defaults=None,
         custom_replay_func=None,
+        native_snippet=None,
+        adj_native_snippet=None,
         skip_forward_codegen=False,
         skip_reverse_codegen=False,
         custom_reverse_num_input_args=-1,
@@ -75,6 +79,7 @@ class Function:
         overloaded_annotations=None,
         code_transformers=[],
         skip_adding_overload=False,
+        require_original_output_arg=False,
     ):
         self.func = func  # points to Python function decorated with @wp.func, may be None for builtins
         self.key = key
@@ -90,7 +95,10 @@ class Function:
         self.defaults = defaults
         # Function instance for a custom implementation of the replay pass
         self.custom_replay_func = custom_replay_func
+        self.native_snippet = native_snippet
+        self.adj_native_snippet = adj_native_snippet
         self.custom_grad_func = None
+        self.require_original_output_arg = require_original_output_arg
 
         if initializer_list_func is None:
             self.initializer_list_func = lambda x, y: False
@@ -170,121 +178,24 @@ class Function:
         # from within a kernel (experimental).
 
         if self.is_builtin() and self.mangled_name:
-            # store the last error encountered during overload resolution
-            error = None
-
-            for f in self.overloads:
-                if f.generic:
+            # For each of this function's existing overloads, we attempt to pack
+            # the given arguments into the C types expected by the corresponding
+            # parameters, and we rinse and repeat until we get a match.
+            for overload in self.overloads:
+                if overload.generic:
                     continue
 
-
-                if not hasattr(warp.context.runtime.core, f.mangled_name):
-                    raise RuntimeError(
-                        f"Couldn't find function {self.key} with mangled name {f.mangled_name} in the Warp native library"
-                    )
-
-                try:
-                    # try and pack args into what the function expects
-                    params = []
-                    for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
-                        a = args[i]
-
-                        # try to convert to a value type (vec3, mat33, etc)
-                        if issubclass(arg_type, ctypes.Array):
-                            # wrap the arg_type (which is an ctypes.Array) in a structure
-                            # to ensure parameter is passed to the .dll by value rather than reference
-                            class ValueArg(ctypes.Structure):
-                                _fields_ = [("value", arg_type)]
-
-                            x = ValueArg()
-
-                            # force conversion to ndarray first (handles tuple / list, Gf.Vec3 case)
-                            if isinstance(a, ctypes.Array) is False:
-                                # assume you want the float32 version of the function so it doesn't just
-                                # grab an override for a random data type:
-                                if arg_type._type_ != ctypes.c_float:
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' does not have c_float type."
-                                    )
-
-                                a = np.array(a)
-
-                                # flatten to 1D array
-                                v = a.flatten()
-                                if len(v) != arg_type._length_:
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' has length {len(v)}, but expected {arg_type._length_}. Could not convert parameter to {arg_type}."
-                                    )
-
-                                for i in range(arg_type._length_):
-                                    x.value[i] = v[i]
-
-                            else:
-                                # already a built-in type, check it matches
-                                if not warp.types.types_equal(type(a), arg_type):
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' has type '{type(a)}' but expected '{arg_type}'"
-                                    )
-
-                                x.value = a
-
-                            params.append(x)
-
-                        else:
-                            try:
-                                # try to pack as a scalar type
-                                params.append(arg_type._type_(a))
-                            except Exception:
-                                raise RuntimeError(
-                                    f"Error calling function {f.key}, unable to pack function parameter type {type(a)} for param {arg_name}, expected {arg_type}"
-                                )
-
-                    # returns the corresponding ctype for a scalar or vector warp type
-                    def type_ctype(dtype):
-                        if dtype == float:
-                            return ctypes.c_float
-                        elif dtype == int:
-                            return ctypes.c_int32
-                        elif issubclass(dtype, ctypes.Array):
-                            return dtype
-                        elif issubclass(dtype, ctypes.Structure):
-                            return dtype
-                        else:
-                            # scalar type
-                            return dtype._type_
-
-                    value_type = type_ctype(f.value_func(None, None, None))
-
-                    # construct return value (passed by address)
-                    ret = value_type()
-                    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
-
-                    params.append(ret_addr)
-
-                    c_func = getattr(warp.context.runtime.core, f.mangled_name)
-                    c_func(*params)
-
-                    if issubclass(value_type, ctypes.Array) or issubclass(value_type, ctypes.Structure):
-                        # return vector types as ctypes
-                        return ret
-                    else:
-                        # return scalar types as int/float
-                        return ret.value
-
-                except Exception as e:
-                    # couldn't pack values to match this overload
-                    # store error and move onto the next one
-                    error = e
-                    continue
+                success, return_value = call_builtin(overload, *args)
+                if success:
+                    return return_value
 
             # overload resolution or call failed
-            # raise the last error encountered during resolution
-            if error:
-                raise error
-
-            raise RuntimeError(f"Error calling function '{f.key}'.")
+            raise RuntimeError(
+                f"Couldn't find a function '{self.key}' compatible with "
+                f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
+            )
 
-        if hasattr(self, "user_overloads") and len(self.user_overloads) > 0:
+        if hasattr(self, "user_overloads") and len(self.user_overloads):
             # user-defined function with overloads
 
             if len(kwargs):
@@ -293,28 +204,26 @@ class Function:
             )
 
             # try and find a matching overload
-            for f in self.user_overloads.values():
-                if len(f.input_types) != len(args):
+            for overload in self.user_overloads.values():
+                if len(overload.input_types) != len(args):
                     continue
-                template_types = list(f.input_types.values())
-                arg_names = list(f.input_types.keys())
+                template_types = list(overload.input_types.values())
+                arg_names = list(overload.input_types.keys())
                 try:
                     # attempt to unify argument types with function template types
                     warp.types.infer_argument_types(args, template_types, arg_names)
-                    return f.func(*args)
+                    return overload.func(*args)
                 except Exception:
                     continue
 
             raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
 
-
-
-
-        if self.func is None:
-            raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
+        # user-defined function with no overloads
+        if self.func is None:
+            raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
 
-
-        return self.func(*args, **kwargs)
+        # this function has no overloads, call it like a plain Python function
+        return self.func(*args, **kwargs)
 
     def is_builtin(self):
         return self.func is None
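Note: in practice, this refactor changes what a user sees when calling a built-in from the CPython interpreter. A minimal sketch of both paths (the specific overloads resolved here are inferred from the code above, not from the package's docs):

```python
import warp as wp

wp.init()

# resolved by trying each overload of wp.dot through call_builtin()
d = wp.dot(wp.vec3(1.0, 2.0, 3.0), wp.vec3(4.0, 5.0, 6.0))  # -> 32.0

# no overload can pack these arguments, so the new, more descriptive
# error is raised:
# RuntimeError: Couldn't find a function 'dot' compatible with the arguments 'str, str'
wp.dot("a", "b")
```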
@@ -427,10 +336,188 @@ class Function:
         return None
 
     def __repr__(self):
-        inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in self.input_types.items()])
+        inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
         return f"<Function {self.key}({inputs_str})>"
 
 
+def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
+    uses_non_warp_array_type = False
+
+    # Retrieve the built-in function from Warp's dll.
+    c_func = getattr(warp.context.runtime.core, func.mangled_name)
+
+    # Try gathering the parameters that the function expects and pack them
+    # into their corresponding C types.
+    c_params = []
+    for i, (_, arg_type) in enumerate(func.input_types.items()):
+        param = params[i]
+
+        try:
+            iter(param)
+        except TypeError:
+            is_array = False
+        else:
+            is_array = True
+
+        if is_array:
+            if not issubclass(arg_type, ctypes.Array):
+                return (False, None)
+
+            # The argument expects a built-in Warp type like a vector or a matrix.
+
+            c_param = None
+
+            if isinstance(param, ctypes.Array):
+                # The given parameter is also a built-in Warp type, so we only need
+                # to make sure that it matches with the argument.
+                if not warp.types.types_equal(type(param), arg_type):
+                    return (False, None)
+
+                if isinstance(param, arg_type):
+                    c_param = param
+                else:
+                    # Cast the value to its argument type to make sure that it
+                    # can be assigned to the field of the `Param` struct.
+                    # This could error otherwise when, for example, the field type
+                    # is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
+                    # even though both types are semantically identical.
+                    c_param = arg_type(param)
+            else:
+                # Flatten the parameter values into a flat 1-D array.
+                arr = []
+                ndim = 1
+                stack = [(0, param)]
+                while stack:
+                    depth, elem = stack.pop(0)
+                    try:
+                        # If `elem` is a sequence, then it should be possible
+                        # to add its elements to the stack for later processing.
+                        stack.extend((depth + 1, x) for x in elem)
+                    except TypeError:
+                        # Since `elem` doesn't seem to be a sequence,
+                        # we must have a leaf value that we need to add to our
+                        # resulting array.
+                        arr.append(elem)
+                        ndim = max(depth, ndim)
+
+                assert ndim > 0
+
+                # Ensure that if the given parameter value is, say, a 2-D array,
+                # then we try to resolve it against a matrix argument rather than
+                # a vector.
+                if ndim > len(arg_type._shape_):
+                    return (False, None)
+
+                elem_count = len(arr)
+                if elem_count != arg_type._length_:
+                    return (False, None)
+
+                # Retrieve the element type of the sequence while ensuring
+                # that it's homogeneous.
+                elem_type = type(arr[0])
+                for i in range(1, elem_count):
+                    if type(arr[i]) is not elem_type:
+                        raise ValueError("All array elements must share the same type.")
+
+                expected_elem_type = arg_type._wp_scalar_type_
+                if not (
+                    elem_type is expected_elem_type
+                    or (elem_type is float and expected_elem_type is warp.types.float32)
+                    or (elem_type is int and expected_elem_type is warp.types.int32)
+                    or (
+                        issubclass(elem_type, np.number)
+                        and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
+                    )
+                ):
+                    # The parameter value has a type not matching the type defined
+                    # for the corresponding argument.
+                    return (False, None)
+
+                if elem_type in warp.types.int_types:
+                    # Pass the value through the expected integer type
+                    # in order to evaluate any integer wrapping.
+                    # For example `uint8(-1)` should result in the value `-255`.
+                    arr = tuple(elem_type._type_(x.value).value for x in arr)
+                elif elem_type in warp.types.float_types:
+                    # Extract the floating-point values.
+                    arr = tuple(x.value for x in arr)
+
+                c_param = arg_type()
+                if warp.types.type_is_matrix(arg_type):
+                    rows, cols = arg_type._shape_
+                    for i in range(rows):
+                        idx_start = i * cols
+                        idx_end = idx_start + cols
+                        c_param[i] = arr[idx_start:idx_end]
+                else:
+                    c_param[:] = arr
+
+                uses_non_warp_array_type = True
+
+            c_params.append(ctypes.byref(c_param))
+        else:
+            if issubclass(arg_type, ctypes.Array):
+                return (False, None)
+
+            if not (
+                isinstance(param, arg_type)
+                or (type(param) is float and arg_type is warp.types.float32)
+                or (type(param) is int and arg_type is warp.types.int32)
+                or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
+            ):
+                return (False, None)
+
+            if type(param) in warp.types.scalar_types:
+                param = param.value
+
+            # try to pack as a scalar type
+            if arg_type == warp.types.float16:
+                c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
+            else:
+                c_params.append(arg_type._type_(param))
+
+    # returns the corresponding ctype for a scalar or vector warp type
+    value_type = func.value_func(None, None, None)
+    if value_type == float:
+        value_ctype = ctypes.c_float
+    elif value_type == int:
+        value_ctype = ctypes.c_int32
+    elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
+        value_ctype = value_type
+    else:
+        # scalar type
+        value_ctype = value_type._type_
+
+    # construct return value (passed by address)
+    ret = value_ctype()
+    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
+    c_params.append(ret_addr)
+
+    # Call the built-in function from Warp's dll.
+    c_func(*c_params)
+
+    # TODO: uncomment when we have a way to print warning messages only once.
+    # if uses_non_warp_array_type:
+    #     warp.utils.warn(
+    #         "Support for built-in functions called with non-Warp array types, "
+    #         "such as lists, tuples, NumPy arrays, and others, will be dropped "
+    #         "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
+    #         "`wp.quat`, or `wp.transform`.",
+    #         DeprecationWarning,
+    #         stacklevel=3
+    #     )
+
+    if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
+        # return vector types as ctypes
+        return (True, ret)
+
+    if value_type == warp.types.float16:
+        return (True, warp.types.half_bits_to_float(ret.value))
+
+    # return scalar types as int/float
+    return (True, ret.value)
+
+
 class KernelHooks:
     def __init__(self, forward, backward):
         self.forward = forward
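The `call_builtin()` helper added above is also what keeps non-Warp sequences working for now; a sketch of the two packing paths (behavior inferred from the code above):

```python
import warp as wp

wp.init()

# a Warp vector is a ctypes.Array subclass and matches its overload directly
wp.length(wp.vec3(3.0, 4.0, 0.0))  # -> 5.0

# a plain Python list is flattened, checked against the argument's scalar
# type, and packed into a temporary vec3f; this path sets
# uses_non_warp_array_type and is slated for deprecation per the TODO above
wp.length([3.0, 4.0, 0.0])  # -> 5.0
```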
@@ -439,10 +526,20 @@ class KernelHooks:
 
 # caches source and compiled entry points for a kernel (will be populated after module loads)
 class Kernel:
-    def __init__(self, func, key, module, options=None, code_transformers=[]):
+    def __init__(self, func, key=None, module=None, options=None, code_transformers=[]):
         self.func = func
-        self.module = module
-        self.key = key
+
+        if module is None:
+            self.module = get_module(func.__module__)
+        else:
+            self.module = module
+
+        if key is None:
+            unique_key = self.module.generate_unique_kernel_key(func.__name__)
+            self.key = unique_key
+        else:
+            self.key = key
+
         self.options = {} if options is None else options
 
         self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
@@ -463,8 +560,8 @@ class Kernel:
         # argument indices by name
         self.arg_indices = dict((a.label, i) for i, a in enumerate(self.adj.args))
 
-        if module:
-            module.register_kernel(self)
+        if self.module:
+            self.module.register_kernel(self)
 
     def infer_argument_types(self, args):
         template_types = list(self.adj.arg_types.values())
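With `key` and `module` now optional, a `Kernel` can be built directly from a plain Python function, with the module inferred from `func.__module__` and the key generated by `Module.generate_unique_kernel_key()` (added further down in this diff). A sketch using the internal `warp.context.Kernel` class directly; the generated key names are inferred from the code, not documented:

```python
import warp as wp

wp.init()

def scale(a: wp.array(dtype=float), s: float):
    tid = wp.tid()
    a[tid] = a[tid] * s

# keys are auto-generated per module as "scale_0", "scale_1", ...
k0 = wp.context.Kernel(scale)
k1 = wp.context.Kernel(scale)

a = wp.array([1.0, 2.0, 3.0], dtype=float)
wp.launch(k0, dim=a.shape, inputs=[a, 2.0])
```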
@@ -541,7 +638,7 @@ def func(f):
     name = warp.codegen.make_full_qualified_name(f)
 
     m = get_module(f.__module__)
-    func = Function(
+    Function(
         func=f, key=name, namespace="", module=m, value_func=None
     )  # value_type not known yet, will be inferred during Adjoint.build()
 
@@ -549,6 +646,24 @@ def func(f):
     return m.functions[name]
 
 
+def func_native(snippet, adj_snippet=None):
+    """
+    Decorator to register native code snippet, @func_native
+    """
+
+    def snippet_func(f):
+        name = warp.codegen.make_full_qualified_name(f)
+
+        m = get_module(f.__module__)
+        func = Function(
+            func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet
+        )  # cuda snippets do not have a return value_type
+
+        return m.functions[name]
+
+    return snippet_func
+
+
 def func_grad(forward_fn):
     """
     Decorator to register a custom gradient function for a given forward function.
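The new `func_native` decorator registers a raw CUDA/C++ snippet that kernels can call like a regular `@wp.func`; `codegen_snippet()` (see the `ModuleBuilder` hunk below) splices it into the generated module source. A sketch modeled on the new warp/tests/test_snippet.py from the file list above, assuming the decorator is exported as `wp.func_native` via the `warp/__init__.py` change; the saxpy body itself is illustrative, not taken from the package:

```python
import warp as wp

snippet = "out[tid] = a * x[tid] + y[tid];"
adj_snippet = """
adj_a += x[tid] * adj_out[tid];
adj_x[tid] += a * adj_out[tid];
adj_y[tid] += adj_out[tid];
"""

# the Python body is never executed; the signature only declares the
# argument types visible to the native snippet
@wp.func_native(snippet, adj_snippet)
def saxpy(a: float, x: wp.array(dtype=float), y: wp.array(dtype=float), out: wp.array(dtype=float), tid: int):
    ...

@wp.kernel
def saxpy_kernel(a: float, x: wp.array(dtype=float), y: wp.array(dtype=float), out: wp.array(dtype=float)):
    tid = wp.tid()
    saxpy(a, x, y, out, tid)
```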
@@ -819,6 +934,7 @@ def add_builtin(
     missing_grad=False,
     native_func=None,
     defaults=None,
+    require_original_output_arg=False,
 ):
     # wrap simple single-type functions with a value_func()
     if value_func is None:
@@ -912,7 +1028,7 @@ def add_builtin(
         # on the generated argument list and skip generation if it fails.
         # This also gives us the return type, which we keep for later:
         try:
-            return_type = value_func(argtypes, {})
+            return_type = value_func(argtypes, {}, [])
         except Exception:
             continue
 
@@ -943,6 +1059,7 @@ def add_builtin(
             hidden=True,
             skip_replay=skip_replay,
             missing_grad=missing_grad,
+            require_original_output_arg=require_original_output_arg,
         )
 
     func = Function(
@@ -963,6 +1080,7 @@ def add_builtin(
         generic=generic,
         native_func=native_func,
         defaults=defaults,
+        require_original_output_arg=require_original_output_arg,
     )
 
     if key in builtin_functions:
@@ -972,7 +1090,7 @@ def add_builtin(
 
     # export means the function will be added to the `warp` module namespace
     # so that users can call it directly from the Python interpreter
-    if export == True:
+    if export:
         if hasattr(warp, key):
             # check that we haven't already created something at this location
             # if it's just an overload stub for auto-complete then overwrite it
@@ -1057,8 +1175,7 @@ class ModuleBuilder:
         while stack:
             s = stack.pop()
 
-            if s not in structs:
-                structs.append(s)
+            structs.append(s)
 
             for var in s.vars.values():
                 if isinstance(var.type, warp.codegen.Struct):
@@ -1090,7 +1207,7 @@ class ModuleBuilder:
         if not func.value_func:
 
             def wrap(adj):
-                def value_type(args, kwds, templates):
+                def value_type(arg_types, kwds, templates):
                     if adj.return_var is None or len(adj.return_var) == 0:
                         return None
                     if len(adj.return_var) == 1:
@@ -1114,9 +1231,14 @@ class ModuleBuilder:
 
         # code-gen all imported functions
         for func in self.functions.keys():
-            source += warp.codegen.codegen_func(
-                func.adj, device=device, options=self.options
-            )
+            if func.native_snippet is None:
+                source += warp.codegen.codegen_func(
+                    func.adj, c_func_name=func.native_func, device=device, options=self.options
+                )
+            else:
+                source += warp.codegen.codegen_snippet(
+                    func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
+                )
 
         for kernel in self.module.kernels.values():
             # each kernel gets an entry point in the module
@@ -1196,6 +1318,10 @@ class Module:
 
         self.content_hash = None
 
+        # number of times module auto-generates kernel key for user
+        # used to ensure unique kernel keys
+        self.count = 0
+
     def register_struct(self, struct):
         self.structs[struct.key] = struct
 
@@ -1238,6 +1364,11 @@ class Module:
         # for a reload of module on next launch
         self.unload()
 
+    def generate_unique_kernel_key(self, key):
+        unique_key = f"{key}_{self.count}"
+        self.count += 1
+        return unique_key
+
     # collect all referenced functions / structs
     # given the AST of a function or kernel
     def find_references(self, adj):
@@ -1251,7 +1382,7 @@ class Module:
             if isinstance(node, ast.Call):
                 try:
                     # try to resolve the function
-                    func, _ = adj.resolve_path(node.func)
+                    func, _ = adj.resolve_static_expression(node.func, eval_types=False)
 
                     # if this is a user-defined function, add a module reference
                     if isinstance(func, warp.context.Function) and func.module is not None:
@@ -1304,9 +1435,24 @@ class Module:
             s = func.adj.source
             ch.update(bytes(s, "utf-8"))
 
+            if func.custom_grad_func:
+                s = func.custom_grad_func.adj.source
+                ch.update(bytes(s, "utf-8"))
+            if func.custom_replay_func:
+                s = func.custom_replay_func.adj.source
+
+            # cache func arg types
+            for arg, arg_type in func.adj.arg_types.items():
+                s = f"{arg}: {get_type_name(arg_type)}"
+                ch.update(bytes(s, "utf-8"))
+
         # kernel source
         for kernel in module.kernels.values():
             ch.update(bytes(kernel.adj.source, "utf-8"))
+            # cache kernel arg types
+            for arg, arg_type in kernel.adj.arg_types.items():
+                s = f"{arg}: {get_type_name(arg_type)}"
+                ch.update(bytes(s, "utf-8"))
             # for generic kernels the Python source is always the same,
             # but we hash the type signatures of all the overloads
             if kernel.is_generic:
@@ -1605,13 +1751,13 @@ class ContextGuard:
     def __enter__(self):
         if self.device.is_cuda:
             runtime.core.cuda_context_push_current(self.device.context)
-        elif is_cuda_available():
+        elif is_cuda_driver_initialized():
             self.saved_context = runtime.core.cuda_context_get_current()
 
     def __exit__(self, exc_type, exc_value, traceback):
         if self.device.is_cuda:
             runtime.core.cuda_context_pop_current()
-        elif is_cuda_available():
+        elif is_cuda_driver_initialized():
             runtime.core.cuda_context_set_current(self.saved_context)
 
@@ -1896,7 +2042,7 @@ class Runtime:
 
         self.core = self.load_dll(warp_lib)
 
-        if llvm_lib:
+        if os.path.exists(llvm_lib):
             self.llvm = self.load_dll(llvm_lib)
             # setup c-types for warp-clang.dll
             self.llvm.lookup.restype = ctypes.c_uint64
@@ -2262,6 +2408,8 @@ class Runtime:
         self.core.cuda_driver_version.restype = ctypes.c_int
         self.core.cuda_toolkit_version.argtypes = None
         self.core.cuda_toolkit_version.restype = ctypes.c_int
+        self.core.cuda_driver_is_initialized.argtypes = None
+        self.core.cuda_driver_is_initialized.restype = ctypes.c_bool
 
         self.core.nvrtc_supported_arch_count.argtypes = None
         self.core.nvrtc_supported_arch_count.restype = ctypes.c_int
@@ -2364,6 +2512,7 @@ class Runtime:
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_size_t,
+            ctypes.c_int,
             ctypes.POINTER(ctypes.c_void_p),
         ]
         self.core.cuda_launch_kernel.restype = ctypes.c_size_t
@@ -2484,8 +2633,15 @@ class Runtime:
                 dll = ctypes.CDLL(dll_path, winmode=0)
             else:
                 dll = ctypes.CDLL(dll_path)
-        except OSError:
-            raise RuntimeError(f"Failed to load the shared library '{dll_path}'")
+        except OSError as e:
+            if "GLIBCXX" in str(e):
+                raise RuntimeError(
+                    f"Failed to load the shared library '{dll_path}'.\n"
+                    "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n"
+                    "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details."
+                ) from e
+            else:
+                raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e
         return dll
 
     def get_device(self, ident: Devicelike = None) -> Device:
@@ -2614,6 +2770,21 @@ def is_device_available(device):
    return device in get_devices()
 
 
+def is_cuda_driver_initialized() -> bool:
+    """Returns ``True`` if the CUDA driver is initialized.
+
+    This is a stricter test than ``is_cuda_available()`` since a CUDA driver
+    call to ``cuCtxGetCurrent`` is made, and the result is compared to
+    `CUDA_SUCCESS`. Note that `CUDA_SUCCESS` is returned by ``cuCtxGetCurrent``
+    even if there is no context bound to the calling CPU thread.
+
+    This can be helpful in cases in which ``cuInit()`` was called before a fork.
+    """
+    assert_initialized()
+
+    return runtime.core.cuda_driver_is_initialized()
+
+
 def get_devices() -> List[Device]:
     """Returns a list of devices supported in this environment."""
 
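A sketch of where the stricter check helps; whether the function is re-exported at the top level (`wp.is_cuda_driver_initialized`) is an assumption, so the module-level path is used here:

```python
import warp as wp

wp.init()

# after a fork, is_cuda_available() can report devices that the child
# process cannot actually use; the driver-level check is more reliable
if wp.context.is_cuda_driver_initialized():
    device = "cuda:0"
else:
    device = "cpu"

a = wp.zeros(16, dtype=float, device=device)
```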
@@ -3090,9 +3261,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
         # - in forward passes, array types have to match
         # - in backward passes, indexed array gradients are regular arrays
         if adjoint:
-            array_matches = type(value) == warp.array
+            array_matches = isinstance(value, warp.array)
         else:
-            array_matches = type(value) == type(arg_type)
+            array_matches = type(value) is type(arg_type)
 
         if not array_matches:
             adj = "adjoint " if adjoint else ""
@@ -3172,7 +3343,7 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
 # represents all data required for a kernel launch
 # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
 class Launch:
-    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None):
+    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
         # if not specified look up hooks
         if not hooks:
             module = kernel.module
@@ -3209,6 +3380,7 @@ class Launch:
         self.params_addr = params_addr
         self.device = device
         self.bounds = bounds
+        self.max_blocks = max_blocks
 
     def set_dim(self, dim):
         self.bounds = warp.types.launch_bounds_t(dim)
@@ -3274,7 +3446,9 @@ class Launch:
         if self.device.is_cpu:
             self.hooks.forward(*self.params)
         else:
-            runtime.core.cuda_launch_kernel(self.device.context, self.hooks.forward, self.bounds.size, self.params_addr)
+            runtime.core.cuda_launch_kernel(
+                self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
+            )
 
 
 def launch(
@@ -3289,6 +3463,7 @@ def launch(
     adjoint=False,
     record_tape=True,
     record_cmd=False,
+    max_blocks=0,
 ):
     """Launch a Warp kernel on the target device
 
@@ -3306,6 +3481,8 @@ def launch(
         adjoint: Whether to run forward or backward pass (typically use False)
         record_tape: When true the launch will be recorded the global wp.Tape() object when present
         record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
+        max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
+            If negative or zero, the maximum hardware value will be used.
     """
 
     assert_initialized()
@@ -3317,7 +3494,7 @@ def launch(
     device = runtime.get_device(device)
 
     # check function is a Kernel
-    if isinstance(kernel, Kernel) == False:
+    if not isinstance(kernel, Kernel):
         raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
 
     # debugging aid
@@ -3399,7 +3576,9 @@ def launch(
                     f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                 )
 
-            runtime.core.cuda_launch_kernel(device.context, hooks.backward, bounds.size, kernel_params)
+            runtime.core.cuda_launch_kernel(
+                device.context, hooks.backward, bounds.size, max_blocks, kernel_params
+            )
 
         else:
             if hooks.forward is None:
@@ -3420,7 +3599,9 @@ def launch(
 
             else:
                 # launch
-                runtime.core.cuda_launch_kernel(device.context, hooks.forward, bounds.size, kernel_params)
+                runtime.core.cuda_launch_kernel(
+                    device.context, hooks.forward, bounds.size, max_blocks, kernel_params
+                )
 
         try:
             runtime.verify_cuda_device(device)
@@ -3430,7 +3611,7 @@ def launch(
 
     # record on tape if one is active
     if runtime.tape and record_tape:
-        runtime.tape.record_launch(kernel, dim, inputs, outputs, device)
+        runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device)
 
 
 def synchronize():
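`max_blocks` is threaded from `launch()` through `Launch` and `Tape.record_launch()`; a sketch of a capped launch, following the docstring added above:

```python
import warp as wp

wp.init()

@wp.kernel
def inc(a: wp.array(dtype=float)):
    tid = wp.tid()
    a[tid] = a[tid] + 1.0

a = wp.zeros(1_000_000, dtype=float, device="cuda:0")

# cap the grid at 256 thread blocks; with max_blocks <= 0 the hardware
# maximum is used, matching the previous behavior
wp.launch(inc, dim=a.shape, inputs=[a], device="cuda:0", max_blocks=256)
```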
@@ -3440,7 +3621,7 @@ def synchronize():
     or memory copies have completed.
     """
 
-    if is_cuda_available():
+    if is_cuda_driver_initialized():
         # save the original context to avoid side effects
         saved_context = runtime.core.cuda_context_get_current()
 
@@ -3490,7 +3671,7 @@ def synchronize_stream(stream_or_device=None):
     runtime.core.cuda_stream_synchronize(stream.device.context, stream.cuda_stream)
 
 
-def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
+def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
     """Force user-defined kernels to be compiled and loaded
 
     Args:
@@ -3498,12 +3679,14 @@ def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
         modules: List of modules to load. If None, load all imported modules.
     """
 
-    if is_cuda_available():
+    if is_cuda_driver_initialized():
         # save original context to avoid side effects
         saved_context = runtime.core.cuda_context_get_current()
 
     if device is None:
         devices = get_devices()
+    elif isinstance(device, list):
+        devices = [get_device(device_item) for device_item in device]
     else:
         devices = [get_device(device)]
 
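`force_load()` now accepts a list of devices in addition to a single device; a sketch:

```python
import warp as wp

wp.init()

# compile and load all imported modules for several devices up front,
# e.g. before starting a CUDA graph capture
wp.force_load(device=["cpu", "cuda:0"])
```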
@@ -3595,7 +3778,7 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
     return get_module(m.__name__).options
 
 
-def capture_begin(device: Devicelike = None, stream=None, force_module_load=True):
+def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
     """Begin capture of a CUDA graph
 
     Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -3609,7 +3792,10 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
 
     """
 
-    if warp.config.verify_cuda == True:
+    if force_module_load is None:
+        force_module_load = warp.config.graph_capture_module_load_default
+
+    if warp.config.verify_cuda:
         raise RuntimeError("Cannot use CUDA error verification during graph capture")
 
     if stream is not None:
@@ -3624,6 +3810,9 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
 
     device.is_capturing = True
 
+    # disable garbage collection to avoid older allocations getting collected during graph capture
+    gc.disable()
+
     with warp.ScopedStream(stream):
         runtime.core.cuda_graph_begin_capture(device.context)
 
@@ -3647,6 +3836,9 @@ def capture_end(device: Devicelike = None, stream=None) -> Graph:
 
     device.is_capturing = False
 
+    # re-enable GC
+    gc.enable()
+
     if graph is None:
         raise RuntimeError(
             "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
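Capture now disables the Python garbage collector so that objects collected mid-capture cannot trigger GPU frees, and `force_module_load` defaults to the new `warp.config.graph_capture_module_load_default` setting. A sketch of the usual capture/replay pattern:

```python
import warp as wp

wp.init()

@wp.kernel
def damp(x: wp.array(dtype=float)):
    tid = wp.tid()
    x[tid] = x[tid] * 0.99

x = wp.zeros(1024, dtype=float, device="cuda:0")

wp.capture_begin(device="cuda:0")   # gc.disable() happens here
for _ in range(10):
    wp.launch(damp, dim=x.shape, inputs=[x], device="cuda:0")
graph = wp.capture_end(device="cuda:0")  # gc.enable() happens here

# replay the ten recorded launches in a single call
wp.capture_launch(graph)
```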
@@ -3841,7 +4033,7 @@ def type_str(t):
     return t.__name__
 
 
-def print_function(f, file, noentry=False):
+def print_function(f, file, noentry=False):  # pragma: no cover
     """Writes a function definition to a file for use in reST documentation
 
     Args:
@@ -3886,7 +4078,7 @@ def print_function(f, file, noentry=False):
     return True
 
 
-def export_functions_rst(file):
+def export_functions_rst(file):  # pragma: no cover
     header = (
         "..\n"
         "   Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
@@ -3906,6 +4098,8 @@ def print_builtins(file):
 
     for t in warp.types.scalar_types:
         print(f".. class:: {t.__name__}", file=file)
+    # Manually add wp.bool since it's inconvenient to add to wp.types.scalar_types:
+    print(f".. class:: {warp.types.bool.__name__}", file=file)
 
     print("\n\nVector Types", file=file)
     print("------------", file=file)
@@ -3925,6 +4119,14 @@ def print_builtins(file):
     print(".. class:: Transformation", file=file)
     print(".. class:: Array", file=file)
 
+    print("\nQuery Types", file=file)
+    print("-------------", file=file)
+    print(".. autoclass:: bvh_query_t", file=file)
+    print(".. autoclass:: hash_grid_query_t", file=file)
+    print(".. autoclass:: mesh_query_aabb_t", file=file)
+    print(".. autoclass:: mesh_query_point_t", file=file)
+    print(".. autoclass:: mesh_query_ray_t", file=file)
+
     # build dictionary of all functions by group
     groups = {}
 
@@ -3958,7 +4160,7 @@ def print_builtins(file):
     print(".. [1] Note: function gradients not implemented for backpropagation.", file=file)
 
 
-def export_stubs(file):
+def export_stubs(file):  # pragma: no cover
     """Generates stub file for auto-complete of builtin functions"""
 
     import textwrap
@@ -3990,6 +4192,8 @@ def export_stubs(file):
     print("Quaternion = Generic[Float]", file=file)
     print("Transformation = Generic[Float]", file=file)
     print("Array = Generic[DType]", file=file)
+    print("FabricArray = Generic[DType]", file=file)
+    print("IndexedFabricArray = Generic[DType]", file=file)
 
     # prepend __init__.py
     with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
@@ -4006,7 +4210,7 @@ def export_stubs(file):
 
         return_str = ""
 
-        if f.export == False or f.hidden == True:  # or f.generic:
+        if not f.export or f.hidden:  # or f.generic:
             continue
 
         try:
@@ -4027,8 +4231,18 @@ def export_stubs(file):
             print("    ...\n\n", file=file)
 
 
-def export_builtins(file):
-    def ctype_str(t):
+def export_builtins(file: io.TextIOBase):  # pragma: no cover
+    def ctype_arg_str(t):
+        if isinstance(t, int):
+            return "int"
+        elif isinstance(t, float):
+            return "float"
+        elif t in warp.types.vector_types:
+            return f"{t.__name__}&"
+        else:
+            return t.__name__
+
+    def ctype_ret_str(t):
         if isinstance(t, int):
             return "int"
         elif isinstance(t, float):
@@ -4036,9 +4250,12 @@ def export_builtins(file):
         else:
             return t.__name__
 
+    file.write("namespace wp {\n\n")
+    file.write('extern "C" {\n\n')
+
     for k, g in builtin_functions.items():
         for f in g.overloads:
-            if f.export == False or f.generic:
+            if not f.export or f.generic:
                 continue
 
             simple = True
@@ -4052,7 +4269,7 @@ def export_builtins(file):
             if not simple or f.variadic:
                 continue
 
-            args = ", ".join(f"{ctype_str(v)} {k}" for k, v in f.input_types.items())
+            args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
             params = ", ".join(f.input_types.keys())
 
             return_type = ""
@@ -4060,7 +4277,7 @@ def export_builtins(file):
             try:
                 # todo: construct a default value for each of the functions args
                 # so we can generate the return type for overloaded functions
-                return_type = ctype_str(f.value_func(None, None, None))
+                return_type = ctype_ret_str(f.value_func(None, None, None))
             except Exception:
                 continue
 
@@ -4068,17 +4285,17 @@ def export_builtins(file):
                 continue
 
             if args == "":
-                print(
-                    f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}", file=file
-                )
+                file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
             elif return_type == "None":
-                print(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}", file=file)
+                file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
             else:
-                print(
-                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}",
-                    file=file,
+                file.write(
                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
                 )
 
+    file.write('\n} // extern "C"\n\n')
+    file.write("} // namespace wp\n")
+
 
 # initialize global runtime
 runtime = None
|