warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/conf.py +3 -4
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/example_dem.py +28 -26
- examples/example_diffray.py +37 -30
- examples/example_fluid.py +7 -3
- examples/example_jacobian_ik.py +1 -1
- examples/example_mesh_intersect.py +10 -7
- examples/example_nvdb.py +3 -3
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +9 -5
- examples/example_sim_cloth.py +29 -25
- examples/example_sim_fk_grad.py +2 -2
- examples/example_sim_fk_grad_torch.py +3 -3
- examples/example_sim_grad_bounce.py +11 -8
- examples/example_sim_grad_cloth.py +12 -9
- examples/example_sim_granular.py +2 -2
- examples/example_sim_granular_collision_sdf.py +13 -13
- examples/example_sim_neo_hookean.py +3 -3
- examples/example_sim_particle_chain.py +2 -2
- examples/example_sim_quadruped.py +8 -5
- examples/example_sim_rigid_chain.py +8 -5
- examples/example_sim_rigid_contact.py +13 -10
- examples/example_sim_rigid_fem.py +2 -2
- examples/example_sim_rigid_gyroscopic.py +2 -2
- examples/example_sim_rigid_kinematics.py +1 -1
- examples/example_sim_trajopt.py +3 -2
- examples/fem/example_apic_fluid.py +5 -7
- examples/fem/example_diffusion_mgpu.py +18 -16
- warp/__init__.py +3 -2
- warp/bin/warp.so +0 -0
- warp/build_dll.py +29 -9
- warp/builtins.py +206 -7
- warp/codegen.py +58 -38
- warp/config.py +3 -1
- warp/context.py +234 -128
- warp/fem/__init__.py +2 -2
- warp/fem/cache.py +2 -1
- warp/fem/field/nodal_field.py +18 -17
- warp/fem/geometry/hexmesh.py +11 -6
- warp/fem/geometry/quadmesh_2d.py +16 -12
- warp/fem/geometry/tetmesh.py +19 -8
- warp/fem/geometry/trimesh_2d.py +18 -7
- warp/fem/integrate.py +341 -196
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +138 -53
- warp/fem/quadrature/quadrature.py +81 -9
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_space.py +169 -51
- warp/fem/space/grid_2d_function_space.py +2 -2
- warp/fem/space/grid_3d_function_space.py +2 -2
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +9 -6
- warp/fem/space/quadmesh_2d_function_space.py +2 -2
- warp/fem/space/shape/cube_shape_function.py +27 -15
- warp/fem/space/shape/square_shape_function.py +29 -18
- warp/fem/space/tetmesh_function_space.py +2 -2
- warp/fem/space/topology.py +10 -0
- warp/fem/space/trimesh_2d_function_space.py +2 -2
- warp/fem/utils.py +10 -5
- warp/native/array.h +49 -8
- warp/native/builtin.h +31 -14
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1177 -1108
- warp/native/intersect.h +4 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +65 -6
- warp/native/mesh.h +126 -5
- warp/native/quat.h +28 -4
- warp/native/vec.h +76 -14
- warp/native/warp.cu +1 -6
- warp/render/render_opengl.py +261 -109
- warp/sim/import_mjcf.py +13 -7
- warp/sim/import_urdf.py +14 -14
- warp/sim/inertia.py +17 -18
- warp/sim/model.py +67 -67
- warp/sim/render.py +1 -1
- warp/sparse.py +6 -6
- warp/stubs.py +19 -81
- warp/tape.py +1 -1
- warp/tests/__main__.py +3 -6
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +102 -106
- warp/tests/test_arithmetic.py +39 -40
- warp/tests/test_array.py +46 -48
- warp/tests/test_array_reduce.py +25 -19
- warp/tests/test_atomic.py +62 -26
- warp/tests/test_bool.py +16 -11
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +9 -12
- warp/tests/test_closest_point_edge_edge.py +53 -57
- warp/tests/test_codegen.py +164 -134
- warp/tests/test_compile_consts.py +13 -19
- warp/tests/test_conditional.py +30 -32
- warp/tests/test_copy.py +9 -12
- warp/tests/test_ctypes.py +90 -98
- warp/tests/test_dense.py +20 -14
- warp/tests/test_devices.py +34 -35
- warp/tests/test_dlpack.py +74 -75
- warp/tests/test_examples.py +215 -97
- warp/tests/test_fabricarray.py +15 -21
- warp/tests/test_fast_math.py +14 -11
- warp/tests/test_fem.py +280 -97
- warp/tests/test_fp16.py +19 -15
- warp/tests/test_func.py +177 -194
- warp/tests/test_generics.py +71 -77
- warp/tests/test_grad.py +83 -32
- warp/tests/test_grad_customs.py +7 -9
- warp/tests/test_hash_grid.py +6 -10
- warp/tests/test_import.py +9 -23
- warp/tests/test_indexedarray.py +19 -21
- warp/tests/test_intersect.py +15 -9
- warp/tests/test_large.py +17 -19
- warp/tests/test_launch.py +14 -17
- warp/tests/test_lerp.py +63 -63
- warp/tests/test_lvalue.py +84 -35
- warp/tests/test_marching_cubes.py +9 -13
- warp/tests/test_mat.py +388 -3004
- warp/tests/test_mat_lite.py +9 -12
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +10 -11
- warp/tests/test_matmul.py +104 -100
- warp/tests/test_matmul_lite.py +72 -98
- warp/tests/test_mesh.py +35 -32
- warp/tests/test_mesh_query_aabb.py +18 -25
- warp/tests/test_mesh_query_point.py +39 -23
- warp/tests/test_mesh_query_ray.py +9 -21
- warp/tests/test_mlp.py +8 -9
- warp/tests/test_model.py +89 -93
- warp/tests/test_modules_lite.py +15 -25
- warp/tests/test_multigpu.py +87 -114
- warp/tests/test_noise.py +10 -12
- warp/tests/test_operators.py +14 -21
- warp/tests/test_options.py +10 -11
- warp/tests/test_pinned.py +16 -18
- warp/tests/test_print.py +16 -20
- warp/tests/test_quat.py +121 -88
- warp/tests/test_rand.py +12 -13
- warp/tests/test_reload.py +27 -32
- warp/tests/test_rounding.py +7 -10
- warp/tests/test_runlength_encode.py +105 -106
- warp/tests/test_smoothstep.py +8 -9
- warp/tests/test_snippet.py +13 -22
- warp/tests/test_sparse.py +30 -29
- warp/tests/test_spatial.py +179 -174
- warp/tests/test_streams.py +100 -107
- warp/tests/test_struct.py +98 -67
- warp/tests/test_tape.py +11 -17
- warp/tests/test_torch.py +89 -86
- warp/tests/test_transient_module.py +9 -12
- warp/tests/test_types.py +328 -50
- warp/tests/test_utils.py +217 -218
- warp/tests/test_vec.py +133 -2133
- warp/tests/test_vec_lite.py +8 -11
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +391 -382
- warp/tests/test_volume_write.py +122 -135
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/{test_base.py → unittest_utils.py} +138 -25
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
- warp/thirdparty/unittest_parallel.py +257 -54
- warp/types.py +119 -98
- warp/utils.py +14 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -239
- warp/tests/test_conditional_unequal_types_kernels.py +0 -14
- warp/tests/test_coverage.py +0 -38
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
|
@@ -79,6 +79,7 @@ class Function:
|
|
|
79
79
|
overloaded_annotations=None,
|
|
80
80
|
code_transformers=[],
|
|
81
81
|
skip_adding_overload=False,
|
|
82
|
+
require_original_output_arg=False,
|
|
82
83
|
):
|
|
83
84
|
self.func = func # points to Python function decorated with @wp.func, may be None for builtins
|
|
84
85
|
self.key = key
|
|
@@ -97,6 +98,7 @@ class Function:
|
|
|
97
98
|
self.native_snippet = native_snippet
|
|
98
99
|
self.adj_native_snippet = adj_native_snippet
|
|
99
100
|
self.custom_grad_func = None
|
|
101
|
+
self.require_original_output_arg = require_original_output_arg
|
|
100
102
|
|
|
101
103
|
if initializer_list_func is None:
|
|
102
104
|
self.initializer_list_func = lambda x, y: False
|
|
@@ -176,112 +178,16 @@ class Function:
|
|
|
176
178
|
# from within a kernel (experimental).
|
|
177
179
|
|
|
178
180
|
if self.is_builtin() and self.mangled_name:
|
|
179
|
-
|
|
180
|
-
|
|
181
|
+
# For each of this function's existing overloads, we attempt to pack
|
|
182
|
+
# the given arguments into the C types expected by the corresponding
|
|
183
|
+
# parameters, and we rinse and repeat until we get a match.
|
|
184
|
+
for overload in self.overloads:
|
|
185
|
+
if overload.generic:
|
|
181
186
|
continue
|
|
182
187
|
|
|
183
|
-
|
|
184
|
-
if
|
|
185
|
-
|
|
186
|
-
f"Couldn't find function {self.key} with mangled name {f.mangled_name} in the Warp native library"
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
try:
|
|
190
|
-
# try and pack args into what the function expects
|
|
191
|
-
params = []
|
|
192
|
-
for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
|
|
193
|
-
a = args[i]
|
|
194
|
-
|
|
195
|
-
# try to convert to a value type (vec3, mat33, etc)
|
|
196
|
-
if issubclass(arg_type, ctypes.Array):
|
|
197
|
-
# wrap the arg_type (which is an ctypes.Array) in a structure
|
|
198
|
-
# to ensure parameter is passed to the .dll by value rather than reference
|
|
199
|
-
class ValueArg(ctypes.Structure):
|
|
200
|
-
_fields_ = [("value", arg_type)]
|
|
201
|
-
|
|
202
|
-
x = ValueArg()
|
|
203
|
-
|
|
204
|
-
# force conversion to ndarray first (handles tuple / list, Gf.Vec3 case)
|
|
205
|
-
if isinstance(a, ctypes.Array) is False:
|
|
206
|
-
# assume you want the float32 version of the function so it doesn't just
|
|
207
|
-
# grab an override for a random data type:
|
|
208
|
-
if arg_type._type_ != ctypes.c_float:
|
|
209
|
-
raise RuntimeError(
|
|
210
|
-
f"Error calling function '{f.key}', parameter for argument '{arg_name}' does not have c_float type."
|
|
211
|
-
)
|
|
212
|
-
|
|
213
|
-
a = np.array(a)
|
|
214
|
-
|
|
215
|
-
# flatten to 1D array
|
|
216
|
-
v = a.flatten()
|
|
217
|
-
if len(v) != arg_type._length_:
|
|
218
|
-
raise RuntimeError(
|
|
219
|
-
f"Error calling function '{f.key}', parameter for argument '{arg_name}' has length {len(v)}, but expected {arg_type._length_}. Could not convert parameter to {arg_type}."
|
|
220
|
-
)
|
|
221
|
-
|
|
222
|
-
for i in range(arg_type._length_):
|
|
223
|
-
x.value[i] = v[i]
|
|
224
|
-
|
|
225
|
-
else:
|
|
226
|
-
# already a built-in type, check it matches
|
|
227
|
-
if not warp.types.types_equal(type(a), arg_type):
|
|
228
|
-
raise RuntimeError(
|
|
229
|
-
f"Error calling function '{f.key}', parameter for argument '{arg_name}' has type '{type(a)}' but expected '{arg_type}'"
|
|
230
|
-
)
|
|
231
|
-
|
|
232
|
-
if isinstance(a, arg_type):
|
|
233
|
-
x.value = a
|
|
234
|
-
else:
|
|
235
|
-
# Cast the value to its argument type to make sure that it can be assigned to the field of the `ValueArg` struct.
|
|
236
|
-
# This could error otherwise when, for example, the field type is set to `vec3i` while the value is of type
|
|
237
|
-
# `vector(length=3, dtype=int)`, even though both types are semantically identical.
|
|
238
|
-
x.value = arg_type(a)
|
|
239
|
-
|
|
240
|
-
params.append(x)
|
|
241
|
-
|
|
242
|
-
else:
|
|
243
|
-
try:
|
|
244
|
-
# try to pack as a scalar type
|
|
245
|
-
params.append(arg_type._type_(a))
|
|
246
|
-
except Exception:
|
|
247
|
-
raise RuntimeError(
|
|
248
|
-
f"Error calling function {f.key}, unable to pack function parameter type {type(a)} for param {arg_name}, expected {arg_type}"
|
|
249
|
-
)
|
|
250
|
-
|
|
251
|
-
# returns the corresponding ctype for a scalar or vector warp type
|
|
252
|
-
def type_ctype(dtype):
|
|
253
|
-
if dtype == float:
|
|
254
|
-
return ctypes.c_float
|
|
255
|
-
elif dtype == int:
|
|
256
|
-
return ctypes.c_int32
|
|
257
|
-
elif issubclass(dtype, ctypes.Array):
|
|
258
|
-
return dtype
|
|
259
|
-
elif issubclass(dtype, ctypes.Structure):
|
|
260
|
-
return dtype
|
|
261
|
-
else:
|
|
262
|
-
# scalar type
|
|
263
|
-
return dtype._type_
|
|
264
|
-
|
|
265
|
-
value_type = type_ctype(f.value_func(None, None, None))
|
|
266
|
-
|
|
267
|
-
# construct return value (passed by address)
|
|
268
|
-
ret = value_type()
|
|
269
|
-
ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
|
|
270
|
-
|
|
271
|
-
params.append(ret_addr)
|
|
272
|
-
|
|
273
|
-
c_func = getattr(warp.context.runtime.core, f.mangled_name)
|
|
274
|
-
c_func(*params)
|
|
275
|
-
|
|
276
|
-
if issubclass(value_type, ctypes.Array) or issubclass(value_type, ctypes.Structure):
|
|
277
|
-
# return vector types as ctypes
|
|
278
|
-
return ret
|
|
279
|
-
|
|
280
|
-
# return scalar types as int/float
|
|
281
|
-
return ret.value
|
|
282
|
-
except Exception:
|
|
283
|
-
# couldn't pack values to match this overload
|
|
284
|
-
continue
|
|
188
|
+
success, return_value = call_builtin(overload, *args)
|
|
189
|
+
if success:
|
|
190
|
+
return return_value
|
|
285
191
|
|
|
286
192
|
# overload resolution or call failed
|
|
287
193
|
raise RuntimeError(
|
|
@@ -289,7 +195,7 @@ class Function:
|
|
|
289
195
|
f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
|
|
290
196
|
)
|
|
291
197
|
|
|
292
|
-
|
|
198
|
+
if hasattr(self, "user_overloads") and len(self.user_overloads):
|
|
293
199
|
# user-defined function with overloads
|
|
294
200
|
|
|
295
201
|
if len(kwargs):
|
|
@@ -298,28 +204,26 @@ class Function:
|
|
|
298
204
|
)
|
|
299
205
|
|
|
300
206
|
# try and find a matching overload
|
|
301
|
-
for
|
|
302
|
-
if len(
|
|
207
|
+
for overload in self.user_overloads.values():
|
|
208
|
+
if len(overload.input_types) != len(args):
|
|
303
209
|
continue
|
|
304
|
-
template_types = list(
|
|
305
|
-
arg_names = list(
|
|
210
|
+
template_types = list(overload.input_types.values())
|
|
211
|
+
arg_names = list(overload.input_types.keys())
|
|
306
212
|
try:
|
|
307
213
|
# attempt to unify argument types with function template types
|
|
308
214
|
warp.types.infer_argument_types(args, template_types, arg_names)
|
|
309
|
-
return
|
|
215
|
+
return overload.func(*args)
|
|
310
216
|
except Exception:
|
|
311
217
|
continue
|
|
312
218
|
|
|
313
219
|
raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
|
|
314
220
|
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
if self.func is None:
|
|
319
|
-
raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
|
|
221
|
+
# user-defined function with no overloads
|
|
222
|
+
if self.func is None:
|
|
223
|
+
raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
|
|
320
224
|
|
|
321
|
-
|
|
322
|
-
|
|
225
|
+
# this function has no overloads, call it like a plain Python function
|
|
226
|
+
return self.func(*args, **kwargs)
|
|
323
227
|
|
|
324
228
|
def is_builtin(self):
|
|
325
229
|
return self.func is None
|
|
@@ -436,6 +340,184 @@ class Function:
|
|
|
436
340
|
return f"<Function {self.key}({inputs_str})>"
|
|
437
341
|
|
|
438
342
|
|
|
343
|
+
def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
|
|
344
|
+
uses_non_warp_array_type = False
|
|
345
|
+
|
|
346
|
+
# Retrieve the built-in function from Warp's dll.
|
|
347
|
+
c_func = getattr(warp.context.runtime.core, func.mangled_name)
|
|
348
|
+
|
|
349
|
+
# Try gathering the parameters that the function expects and pack them
|
|
350
|
+
# into their corresponding C types.
|
|
351
|
+
c_params = []
|
|
352
|
+
for i, (_, arg_type) in enumerate(func.input_types.items()):
|
|
353
|
+
param = params[i]
|
|
354
|
+
|
|
355
|
+
try:
|
|
356
|
+
iter(param)
|
|
357
|
+
except TypeError:
|
|
358
|
+
is_array = False
|
|
359
|
+
else:
|
|
360
|
+
is_array = True
|
|
361
|
+
|
|
362
|
+
if is_array:
|
|
363
|
+
if not issubclass(arg_type, ctypes.Array):
|
|
364
|
+
return (False, None)
|
|
365
|
+
|
|
366
|
+
# The argument expects a built-in Warp type like a vector or a matrix.
|
|
367
|
+
|
|
368
|
+
c_param = None
|
|
369
|
+
|
|
370
|
+
if isinstance(param, ctypes.Array):
|
|
371
|
+
# The given parameter is also a built-in Warp type, so we only need
|
|
372
|
+
# to make sure that it matches with the argument.
|
|
373
|
+
if not warp.types.types_equal(type(param), arg_type):
|
|
374
|
+
return (False, None)
|
|
375
|
+
|
|
376
|
+
if isinstance(param, arg_type):
|
|
377
|
+
c_param = param
|
|
378
|
+
else:
|
|
379
|
+
# Cast the value to its argument type to make sure that it
|
|
380
|
+
# can be assigned to the field of the `Param` struct.
|
|
381
|
+
# This could error otherwise when, for example, the field type
|
|
382
|
+
# is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
|
|
383
|
+
# even though both types are semantically identical.
|
|
384
|
+
c_param = arg_type(param)
|
|
385
|
+
else:
|
|
386
|
+
# Flatten the parameter values into a flat 1-D array.
|
|
387
|
+
arr = []
|
|
388
|
+
ndim = 1
|
|
389
|
+
stack = [(0, param)]
|
|
390
|
+
while stack:
|
|
391
|
+
depth, elem = stack.pop(0)
|
|
392
|
+
try:
|
|
393
|
+
# If `elem` is a sequence, then it should be possible
|
|
394
|
+
# to add its elements to the stack for later processing.
|
|
395
|
+
stack.extend((depth + 1, x) for x in elem)
|
|
396
|
+
except TypeError:
|
|
397
|
+
# Since `elem` doesn't seem to be a sequence,
|
|
398
|
+
# we must have a leaf value that we need to add to our
|
|
399
|
+
# resulting array.
|
|
400
|
+
arr.append(elem)
|
|
401
|
+
ndim = max(depth, ndim)
|
|
402
|
+
|
|
403
|
+
assert ndim > 0
|
|
404
|
+
|
|
405
|
+
# Ensure that if the given parameter value is, say, a 2-D array,
|
|
406
|
+
# then we try to resolve it against a matrix argument rather than
|
|
407
|
+
# a vector.
|
|
408
|
+
if ndim > len(arg_type._shape_):
|
|
409
|
+
return (False, None)
|
|
410
|
+
|
|
411
|
+
elem_count = len(arr)
|
|
412
|
+
if elem_count != arg_type._length_:
|
|
413
|
+
return (False, None)
|
|
414
|
+
|
|
415
|
+
# Retrieve the element type of the sequence while ensuring
|
|
416
|
+
# that it's homogeneous.
|
|
417
|
+
elem_type = type(arr[0])
|
|
418
|
+
for i in range(1, elem_count):
|
|
419
|
+
if type(arr[i]) is not elem_type:
|
|
420
|
+
raise ValueError("All array elements must share the same type.")
|
|
421
|
+
|
|
422
|
+
expected_elem_type = arg_type._wp_scalar_type_
|
|
423
|
+
if not (
|
|
424
|
+
elem_type is expected_elem_type
|
|
425
|
+
or (elem_type is float and expected_elem_type is warp.types.float32)
|
|
426
|
+
or (elem_type is int and expected_elem_type is warp.types.int32)
|
|
427
|
+
or (
|
|
428
|
+
issubclass(elem_type, np.number)
|
|
429
|
+
and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
|
|
430
|
+
)
|
|
431
|
+
):
|
|
432
|
+
# The parameter value has a type not matching the type defined
|
|
433
|
+
# for the corresponding argument.
|
|
434
|
+
return (False, None)
|
|
435
|
+
|
|
436
|
+
if elem_type in warp.types.int_types:
|
|
437
|
+
# Pass the value through the expected integer type
|
|
438
|
+
# in order to evaluate any integer wrapping.
|
|
439
|
+
# For example `uint8(-1)` should result in the value `-255`.
|
|
440
|
+
arr = tuple(elem_type._type_(x.value).value for x in arr)
|
|
441
|
+
elif elem_type in warp.types.float_types:
|
|
442
|
+
# Extract the floating-point values.
|
|
443
|
+
arr = tuple(x.value for x in arr)
|
|
444
|
+
|
|
445
|
+
c_param = arg_type()
|
|
446
|
+
if warp.types.type_is_matrix(arg_type):
|
|
447
|
+
rows, cols = arg_type._shape_
|
|
448
|
+
for i in range(rows):
|
|
449
|
+
idx_start = i * cols
|
|
450
|
+
idx_end = idx_start + cols
|
|
451
|
+
c_param[i] = arr[idx_start:idx_end]
|
|
452
|
+
else:
|
|
453
|
+
c_param[:] = arr
|
|
454
|
+
|
|
455
|
+
uses_non_warp_array_type = True
|
|
456
|
+
|
|
457
|
+
c_params.append(ctypes.byref(c_param))
|
|
458
|
+
else:
|
|
459
|
+
if issubclass(arg_type, ctypes.Array):
|
|
460
|
+
return (False, None)
|
|
461
|
+
|
|
462
|
+
if not (
|
|
463
|
+
isinstance(param, arg_type)
|
|
464
|
+
or (type(param) is float and arg_type is warp.types.float32)
|
|
465
|
+
or (type(param) is int and arg_type is warp.types.int32)
|
|
466
|
+
or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
|
|
467
|
+
):
|
|
468
|
+
return (False, None)
|
|
469
|
+
|
|
470
|
+
if type(param) in warp.types.scalar_types:
|
|
471
|
+
param = param.value
|
|
472
|
+
|
|
473
|
+
# try to pack as a scalar type
|
|
474
|
+
if arg_type == warp.types.float16:
|
|
475
|
+
c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
|
|
476
|
+
else:
|
|
477
|
+
c_params.append(arg_type._type_(param))
|
|
478
|
+
|
|
479
|
+
# returns the corresponding ctype for a scalar or vector warp type
|
|
480
|
+
value_type = func.value_func(None, None, None)
|
|
481
|
+
if value_type == float:
|
|
482
|
+
value_ctype = ctypes.c_float
|
|
483
|
+
elif value_type == int:
|
|
484
|
+
value_ctype = ctypes.c_int32
|
|
485
|
+
elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
|
|
486
|
+
value_ctype = value_type
|
|
487
|
+
else:
|
|
488
|
+
# scalar type
|
|
489
|
+
value_ctype = value_type._type_
|
|
490
|
+
|
|
491
|
+
# construct return value (passed by address)
|
|
492
|
+
ret = value_ctype()
|
|
493
|
+
ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
|
|
494
|
+
c_params.append(ret_addr)
|
|
495
|
+
|
|
496
|
+
# Call the built-in function from Warp's dll.
|
|
497
|
+
c_func(*c_params)
|
|
498
|
+
|
|
499
|
+
# TODO: uncomment when we have a way to print warning messages only once.
|
|
500
|
+
# if uses_non_warp_array_type:
|
|
501
|
+
# warp.utils.warn(
|
|
502
|
+
# "Support for built-in functions called with non-Warp array types, "
|
|
503
|
+
# "such as lists, tuples, NumPy arrays, and others, will be dropped "
|
|
504
|
+
# "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
|
|
505
|
+
# "`wp.quat`, or `wp.transform`.",
|
|
506
|
+
# DeprecationWarning,
|
|
507
|
+
# stacklevel=3
|
|
508
|
+
# )
|
|
509
|
+
|
|
510
|
+
if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
|
|
511
|
+
# return vector types as ctypes
|
|
512
|
+
return (True, ret)
|
|
513
|
+
|
|
514
|
+
if value_type == warp.types.float16:
|
|
515
|
+
return (True, warp.types.half_bits_to_float(ret.value))
|
|
516
|
+
|
|
517
|
+
# return scalar types as int/float
|
|
518
|
+
return (True, ret.value)
|
|
519
|
+
|
|
520
|
+
|
|
439
521
|
class KernelHooks:
|
|
440
522
|
def __init__(self, forward, backward):
|
|
441
523
|
self.forward = forward
|
|
@@ -852,6 +934,7 @@ def add_builtin(
|
|
|
852
934
|
missing_grad=False,
|
|
853
935
|
native_func=None,
|
|
854
936
|
defaults=None,
|
|
937
|
+
require_original_output_arg=False,
|
|
855
938
|
):
|
|
856
939
|
# wrap simple single-type functions with a value_func()
|
|
857
940
|
if value_func is None:
|
|
@@ -976,6 +1059,7 @@ def add_builtin(
|
|
|
976
1059
|
hidden=True,
|
|
977
1060
|
skip_replay=skip_replay,
|
|
978
1061
|
missing_grad=missing_grad,
|
|
1062
|
+
require_original_output_arg=require_original_output_arg,
|
|
979
1063
|
)
|
|
980
1064
|
|
|
981
1065
|
func = Function(
|
|
@@ -996,6 +1080,7 @@ def add_builtin(
|
|
|
996
1080
|
generic=generic,
|
|
997
1081
|
native_func=native_func,
|
|
998
1082
|
defaults=defaults,
|
|
1083
|
+
require_original_output_arg=require_original_output_arg,
|
|
999
1084
|
)
|
|
1000
1085
|
|
|
1001
1086
|
if key in builtin_functions:
|
|
@@ -1005,7 +1090,7 @@ def add_builtin(
|
|
|
1005
1090
|
|
|
1006
1091
|
# export means the function will be added to the `warp` module namespace
|
|
1007
1092
|
# so that users can call it directly from the Python interpreter
|
|
1008
|
-
if export
|
|
1093
|
+
if export:
|
|
1009
1094
|
if hasattr(warp, key):
|
|
1010
1095
|
# check that we haven't already created something at this location
|
|
1011
1096
|
# if it's just an overload stub for auto-complete then overwrite it
|
|
@@ -1355,7 +1440,7 @@ class Module:
|
|
|
1355
1440
|
ch.update(bytes(s, "utf-8"))
|
|
1356
1441
|
if func.custom_replay_func:
|
|
1357
1442
|
s = func.custom_replay_func.adj.source
|
|
1358
|
-
|
|
1443
|
+
|
|
1359
1444
|
# cache func arg types
|
|
1360
1445
|
for arg, arg_type in func.adj.arg_types.items():
|
|
1361
1446
|
s = f"{arg}: {get_type_name(arg_type)}"
|
|
@@ -3409,7 +3494,7 @@ def launch(
|
|
|
3409
3494
|
device = runtime.get_device(device)
|
|
3410
3495
|
|
|
3411
3496
|
# check function is a Kernel
|
|
3412
|
-
if isinstance(kernel, Kernel)
|
|
3497
|
+
if not isinstance(kernel, Kernel):
|
|
3413
3498
|
raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
|
|
3414
3499
|
|
|
3415
3500
|
# debugging aid
|
|
@@ -3693,7 +3778,7 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
|
|
|
3693
3778
|
return get_module(m.__name__).options
|
|
3694
3779
|
|
|
3695
3780
|
|
|
3696
|
-
def capture_begin(device: Devicelike = None, stream=None, force_module_load=
|
|
3781
|
+
def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
|
|
3697
3782
|
"""Begin capture of a CUDA graph
|
|
3698
3783
|
|
|
3699
3784
|
Captures all subsequent kernel launches and memory operations on CUDA devices.
|
|
@@ -3707,7 +3792,10 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True
|
|
|
3707
3792
|
|
|
3708
3793
|
"""
|
|
3709
3794
|
|
|
3710
|
-
if
|
|
3795
|
+
if force_module_load is None:
|
|
3796
|
+
force_module_load = warp.config.graph_capture_module_load_default
|
|
3797
|
+
|
|
3798
|
+
if warp.config.verify_cuda:
|
|
3711
3799
|
raise RuntimeError("Cannot use CUDA error verification during graph capture")
|
|
3712
3800
|
|
|
3713
3801
|
if stream is not None:
|
|
@@ -3990,7 +4078,7 @@ def print_function(f, file, noentry=False): # pragma: no cover
|
|
|
3990
4078
|
return True
|
|
3991
4079
|
|
|
3992
4080
|
|
|
3993
|
-
def
|
|
4081
|
+
def export_functions_rst(file): # pragma: no cover
|
|
3994
4082
|
header = (
|
|
3995
4083
|
"..\n"
|
|
3996
4084
|
" Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
|
|
@@ -4031,6 +4119,14 @@ def print_builtins(file): # pragma: no cover
|
|
|
4031
4119
|
print(".. class:: Transformation", file=file)
|
|
4032
4120
|
print(".. class:: Array", file=file)
|
|
4033
4121
|
|
|
4122
|
+
print("\nQuery Types", file=file)
|
|
4123
|
+
print("-------------", file=file)
|
|
4124
|
+
print(".. autoclass:: bvh_query_t", file=file)
|
|
4125
|
+
print(".. autoclass:: hash_grid_query_t", file=file)
|
|
4126
|
+
print(".. autoclass:: mesh_query_aabb_t", file=file)
|
|
4127
|
+
print(".. autoclass:: mesh_query_point_t", file=file)
|
|
4128
|
+
print(".. autoclass:: mesh_query_ray_t", file=file)
|
|
4129
|
+
|
|
4034
4130
|
# build dictionary of all functions by group
|
|
4035
4131
|
groups = {}
|
|
4036
4132
|
|
|
@@ -4114,7 +4210,7 @@ def export_stubs(file): # pragma: no cover
|
|
|
4114
4210
|
|
|
4115
4211
|
return_str = ""
|
|
4116
4212
|
|
|
4117
|
-
if f.export
|
|
4213
|
+
if not f.export or f.hidden: # or f.generic:
|
|
4118
4214
|
continue
|
|
4119
4215
|
|
|
4120
4216
|
try:
|
|
@@ -4136,7 +4232,17 @@ def export_stubs(file): # pragma: no cover
|
|
|
4136
4232
|
|
|
4137
4233
|
|
|
4138
4234
|
def export_builtins(file: io.TextIOBase): # pragma: no cover
|
|
4139
|
-
def
|
|
4235
|
+
def ctype_arg_str(t):
|
|
4236
|
+
if isinstance(t, int):
|
|
4237
|
+
return "int"
|
|
4238
|
+
elif isinstance(t, float):
|
|
4239
|
+
return "float"
|
|
4240
|
+
elif t in warp.types.vector_types:
|
|
4241
|
+
return f"{t.__name__}&"
|
|
4242
|
+
else:
|
|
4243
|
+
return t.__name__
|
|
4244
|
+
|
|
4245
|
+
def ctype_ret_str(t):
|
|
4140
4246
|
if isinstance(t, int):
|
|
4141
4247
|
return "int"
|
|
4142
4248
|
elif isinstance(t, float):
|
|
@@ -4149,7 +4255,7 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
|
|
|
4149
4255
|
|
|
4150
4256
|
for k, g in builtin_functions.items():
|
|
4151
4257
|
for f in g.overloads:
|
|
4152
|
-
if f.export
|
|
4258
|
+
if not f.export or f.generic:
|
|
4153
4259
|
continue
|
|
4154
4260
|
|
|
4155
4261
|
simple = True
|
|
@@ -4163,7 +4269,7 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
|
|
|
4163
4269
|
if not simple or f.variadic:
|
|
4164
4270
|
continue
|
|
4165
4271
|
|
|
4166
|
-
args = ", ".join(f"{
|
|
4272
|
+
args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
|
|
4167
4273
|
params = ", ".join(f.input_types.keys())
|
|
4168
4274
|
|
|
4169
4275
|
return_type = ""
|
|
@@ -4171,7 +4277,7 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
|
|
|
4171
4277
|
try:
|
|
4172
4278
|
# todo: construct a default value for each of the functions args
|
|
4173
4279
|
# so we can generate the return type for overloaded functions
|
|
4174
|
-
return_type =
|
|
4280
|
+
return_type = ctype_ret_str(f.value_func(None, None, None))
|
|
4175
4281
|
except Exception:
|
|
4176
4282
|
continue
|
|
4177
4283
|
|
warp/fem/__init__.py
CHANGED
|
@@ -2,12 +2,12 @@ from .geometry import Geometry, Grid2D, Trimesh2D, Quadmesh2D, Grid3D, Tetmesh,
|
|
|
2
2
|
from .geometry import GeometryPartition, LinearGeometryPartition, ExplicitGeometryPartition
|
|
3
3
|
|
|
4
4
|
from .space import FunctionSpace, make_polynomial_space, ElementBasis
|
|
5
|
-
from .space import BasisSpace, make_polynomial_basis_space, make_collocated_function_space
|
|
5
|
+
from .space import BasisSpace, PointBasisSpace, make_polynomial_basis_space, make_collocated_function_space
|
|
6
6
|
from .space import DofMapper, SkewSymmetricTensorMapper, SymmetricTensorMapper
|
|
7
7
|
from .space import SpaceTopology, SpacePartition, SpaceRestriction, make_space_partition, make_space_restriction
|
|
8
8
|
|
|
9
9
|
from .domain import GeometryDomain, Cells, Sides, BoundarySides, FrontierSides
|
|
10
|
-
from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, PicQuadrature
|
|
10
|
+
from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, ExplicitQuadrature, PicQuadrature
|
|
11
11
|
from .polynomial import Polynomial
|
|
12
12
|
|
|
13
13
|
from .field import FieldLike, DiscreteField, make_test, make_trial, make_restriction
|
warp/fem/cache.py
CHANGED
|
@@ -95,6 +95,7 @@ def dynamic_struct(suffix: str, use_qualified_name=False):
|
|
|
95
95
|
def get_integrand_function(
|
|
96
96
|
integrand: "warp.fem.operator.Integrand",
|
|
97
97
|
suffix: str,
|
|
98
|
+
func=None,
|
|
98
99
|
annotations=None,
|
|
99
100
|
code_transformers=[],
|
|
100
101
|
):
|
|
@@ -102,7 +103,7 @@ def get_integrand_function(
|
|
|
102
103
|
|
|
103
104
|
if key not in _func_cache:
|
|
104
105
|
_func_cache[key] = wp.Function(
|
|
105
|
-
func=integrand.func,
|
|
106
|
+
func=integrand.func if func is None else func,
|
|
106
107
|
key=key,
|
|
107
108
|
namespace="",
|
|
108
109
|
module=integrand.module,
|
warp/fem/field/nodal_field.py
CHANGED
|
@@ -84,15 +84,14 @@ class NodalFieldBase(DiscreteField):
|
|
|
84
84
|
if not self.gradient_valid():
|
|
85
85
|
return None
|
|
86
86
|
|
|
87
|
-
@cache.dynamic_func(suffix=self.name
|
|
88
|
-
def
|
|
87
|
+
@cache.dynamic_func(suffix=self.name)
|
|
88
|
+
def eval_grad_inner_ref_space(args: self.ElementEvalArg, s: Sample):
|
|
89
89
|
res = utils.generalized_outer(
|
|
90
90
|
self._read_node_value(args, s.element_index, 0),
|
|
91
91
|
self.space.element_inner_weight_gradient(
|
|
92
92
|
args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
|
|
93
93
|
),
|
|
94
94
|
)
|
|
95
|
-
|
|
96
95
|
for k in range(1, NODES_PER_ELEMENT):
|
|
97
96
|
res += utils.generalized_outer(
|
|
98
97
|
self._read_node_value(args, s.element_index, k),
|
|
@@ -100,14 +99,15 @@ class NodalFieldBase(DiscreteField):
|
|
|
100
99
|
args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
|
|
101
100
|
),
|
|
102
101
|
)
|
|
103
|
-
|
|
104
|
-
if world_space:
|
|
105
|
-
grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
|
|
106
|
-
return utils.apply_right(res, grad_transform)
|
|
107
|
-
|
|
108
102
|
return res
|
|
109
103
|
|
|
110
|
-
|
|
104
|
+
@cache.dynamic_func(suffix=self.name)
|
|
105
|
+
def eval_grad_inner_world_space(args: self.ElementEvalArg, s: Sample):
|
|
106
|
+
grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
|
|
107
|
+
res = eval_grad_inner_ref_space(args, s)
|
|
108
|
+
return utils.apply_right(res, grad_transform)
|
|
109
|
+
|
|
110
|
+
return eval_grad_inner_world_space if world_space else eval_grad_inner_ref_space
|
|
111
111
|
|
|
112
112
|
def _make_eval_div_inner(self):
|
|
113
113
|
NODES_PER_ELEMENT = self.space.topology.NODES_PER_ELEMENT
|
|
@@ -173,8 +173,8 @@ class NodalFieldBase(DiscreteField):
|
|
|
173
173
|
if not self.gradient_valid():
|
|
174
174
|
return None
|
|
175
175
|
|
|
176
|
-
@cache.dynamic_func(suffix=self.name
|
|
177
|
-
def
|
|
176
|
+
@cache.dynamic_func(suffix=self.name)
|
|
177
|
+
def eval_grad_outer_ref_space(args: self.ElementEvalArg, s: Sample):
|
|
178
178
|
res = utils.generalized_outer(
|
|
179
179
|
self._read_node_value(args, s.element_index, 0),
|
|
180
180
|
self.space.element_outer_weight_gradient(
|
|
@@ -188,14 +188,15 @@ class NodalFieldBase(DiscreteField):
|
|
|
188
188
|
args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
|
|
189
189
|
),
|
|
190
190
|
)
|
|
191
|
-
|
|
192
|
-
if world_space:
|
|
193
|
-
grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
|
|
194
|
-
return utils.apply_right(res, grad_transform)
|
|
195
|
-
|
|
196
191
|
return res
|
|
197
192
|
|
|
198
|
-
|
|
193
|
+
@cache.dynamic_func(suffix=self.name)
|
|
194
|
+
def eval_grad_outer_world_space(args: self.ElementEvalArg, s: Sample):
|
|
195
|
+
grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
|
|
196
|
+
res = eval_grad_outer_ref_space(args, s)
|
|
197
|
+
return utils.apply_right(res, grad_transform)
|
|
198
|
+
|
|
199
|
+
return eval_grad_outer_world_space if world_space else eval_grad_outer_ref_space
|
|
199
200
|
|
|
200
201
|
def _make_eval_div_outer(self):
|
|
201
202
|
NODES_PER_ELEMENT = self.space.topology.NODES_PER_ELEMENT
|
warp/fem/geometry/hexmesh.py
CHANGED
|
@@ -1,11 +1,16 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
|
-
import warp as wp
|
|
3
2
|
|
|
4
|
-
|
|
5
|
-
from warp.fem.cache import
|
|
3
|
+
import warp as wp
|
|
4
|
+
from warp.fem.cache import (
|
|
5
|
+
TemporaryStore,
|
|
6
|
+
borrow_temporary,
|
|
7
|
+
borrow_temporary_like,
|
|
8
|
+
cached_arg_value,
|
|
9
|
+
)
|
|
10
|
+
from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
|
|
6
11
|
|
|
12
|
+
from .element import Cube, Square
|
|
7
13
|
from .geometry import Geometry
|
|
8
|
-
from .element import Square, Cube
|
|
9
14
|
|
|
10
15
|
|
|
11
16
|
@wp.struct
|
|
@@ -493,7 +498,7 @@ class Hexmesh(Geometry):
|
|
|
493
498
|
wp.copy(
|
|
494
499
|
dest=face_count.array, src=vertex_unique_face_offsets.array, src_offset=self.vertex_count() - 1, count=1
|
|
495
500
|
)
|
|
496
|
-
wp.synchronize_stream(wp.get_stream())
|
|
501
|
+
wp.synchronize_stream(wp.get_stream(device))
|
|
497
502
|
face_count = int(face_count.array.numpy()[0])
|
|
498
503
|
else:
|
|
499
504
|
face_count = int(vertex_unique_face_offsets.array.numpy()[self.vertex_count() - 1])
|
|
@@ -603,7 +608,7 @@ class Hexmesh(Geometry):
|
|
|
603
608
|
src_offset=self.vertex_count() - 1,
|
|
604
609
|
count=1,
|
|
605
610
|
)
|
|
606
|
-
wp.synchronize_stream(wp.get_stream())
|
|
611
|
+
wp.synchronize_stream(wp.get_stream(device))
|
|
607
612
|
self._edge_count = int(edge_count.array.numpy()[0])
|
|
608
613
|
else:
|
|
609
614
|
self._edge_count = int(vertex_unique_edge_offsets.array.numpy()[self.vertex_count() - 1])
|