warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -79,6 +79,7 @@ class Function:
79
79
  overloaded_annotations=None,
80
80
  code_transformers=[],
81
81
  skip_adding_overload=False,
82
+ require_original_output_arg=False,
82
83
  ):
83
84
  self.func = func # points to Python function decorated with @wp.func, may be None for builtins
84
85
  self.key = key
@@ -97,6 +98,7 @@ class Function:
97
98
  self.native_snippet = native_snippet
98
99
  self.adj_native_snippet = adj_native_snippet
99
100
  self.custom_grad_func = None
101
+ self.require_original_output_arg = require_original_output_arg
100
102
 
101
103
  if initializer_list_func is None:
102
104
  self.initializer_list_func = lambda x, y: False
@@ -176,112 +178,16 @@ class Function:
176
178
  # from within a kernel (experimental).
177
179
 
178
180
  if self.is_builtin() and self.mangled_name:
179
- for f in self.overloads:
180
- if f.generic:
181
+ # For each of this function's existing overloads, we attempt to pack
182
+ # the given arguments into the C types expected by the corresponding
183
+ # parameters, and we rinse and repeat until we get a match.
184
+ for overload in self.overloads:
185
+ if overload.generic:
181
186
  continue
182
187
 
183
- # try and find builtin in the warp.dll
184
- if not hasattr(warp.context.runtime.core, f.mangled_name):
185
- raise RuntimeError(
186
- f"Couldn't find function {self.key} with mangled name {f.mangled_name} in the Warp native library"
187
- )
188
-
189
- try:
190
- # try and pack args into what the function expects
191
- params = []
192
- for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
193
- a = args[i]
194
-
195
- # try to convert to a value type (vec3, mat33, etc)
196
- if issubclass(arg_type, ctypes.Array):
197
- # wrap the arg_type (which is an ctypes.Array) in a structure
198
- # to ensure parameter is passed to the .dll by value rather than reference
199
- class ValueArg(ctypes.Structure):
200
- _fields_ = [("value", arg_type)]
201
-
202
- x = ValueArg()
203
-
204
- # force conversion to ndarray first (handles tuple / list, Gf.Vec3 case)
205
- if isinstance(a, ctypes.Array) is False:
206
- # assume you want the float32 version of the function so it doesn't just
207
- # grab an override for a random data type:
208
- if arg_type._type_ != ctypes.c_float:
209
- raise RuntimeError(
210
- f"Error calling function '{f.key}', parameter for argument '{arg_name}' does not have c_float type."
211
- )
212
-
213
- a = np.array(a)
214
-
215
- # flatten to 1D array
216
- v = a.flatten()
217
- if len(v) != arg_type._length_:
218
- raise RuntimeError(
219
- f"Error calling function '{f.key}', parameter for argument '{arg_name}' has length {len(v)}, but expected {arg_type._length_}. Could not convert parameter to {arg_type}."
220
- )
221
-
222
- for i in range(arg_type._length_):
223
- x.value[i] = v[i]
224
-
225
- else:
226
- # already a built-in type, check it matches
227
- if not warp.types.types_equal(type(a), arg_type):
228
- raise RuntimeError(
229
- f"Error calling function '{f.key}', parameter for argument '{arg_name}' has type '{type(a)}' but expected '{arg_type}'"
230
- )
231
-
232
- if isinstance(a, arg_type):
233
- x.value = a
234
- else:
235
- # Cast the value to its argument type to make sure that it can be assigned to the field of the `ValueArg` struct.
236
- # This could error otherwise when, for example, the field type is set to `vec3i` while the value is of type
237
- # `vector(length=3, dtype=int)`, even though both types are semantically identical.
238
- x.value = arg_type(a)
239
-
240
- params.append(x)
241
-
242
- else:
243
- try:
244
- # try to pack as a scalar type
245
- params.append(arg_type._type_(a))
246
- except Exception:
247
- raise RuntimeError(
248
- f"Error calling function {f.key}, unable to pack function parameter type {type(a)} for param {arg_name}, expected {arg_type}"
249
- )
250
-
251
- # returns the corresponding ctype for a scalar or vector warp type
252
- def type_ctype(dtype):
253
- if dtype == float:
254
- return ctypes.c_float
255
- elif dtype == int:
256
- return ctypes.c_int32
257
- elif issubclass(dtype, ctypes.Array):
258
- return dtype
259
- elif issubclass(dtype, ctypes.Structure):
260
- return dtype
261
- else:
262
- # scalar type
263
- return dtype._type_
264
-
265
- value_type = type_ctype(f.value_func(None, None, None))
266
-
267
- # construct return value (passed by address)
268
- ret = value_type()
269
- ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
270
-
271
- params.append(ret_addr)
272
-
273
- c_func = getattr(warp.context.runtime.core, f.mangled_name)
274
- c_func(*params)
275
-
276
- if issubclass(value_type, ctypes.Array) or issubclass(value_type, ctypes.Structure):
277
- # return vector types as ctypes
278
- return ret
279
-
280
- # return scalar types as int/float
281
- return ret.value
282
- except Exception:
283
- # couldn't pack values to match this overload
284
- continue
188
+ success, return_value = call_builtin(overload, *args)
189
+ if success:
190
+ return return_value
285
191
 
286
192
  # overload resolution or call failed
287
193
  raise RuntimeError(
@@ -289,7 +195,7 @@ class Function:
289
195
  f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
290
196
  )
291
197
 
292
- elif hasattr(self, "user_overloads") and len(self.user_overloads):
198
+ if hasattr(self, "user_overloads") and len(self.user_overloads):
293
199
  # user-defined function with overloads
294
200
 
295
201
  if len(kwargs):
@@ -298,28 +204,26 @@ class Function:
298
204
  )
299
205
 
300
206
  # try and find a matching overload
301
- for f in self.user_overloads.values():
302
- if len(f.input_types) != len(args):
207
+ for overload in self.user_overloads.values():
208
+ if len(overload.input_types) != len(args):
303
209
  continue
304
- template_types = list(f.input_types.values())
305
- arg_names = list(f.input_types.keys())
210
+ template_types = list(overload.input_types.values())
211
+ arg_names = list(overload.input_types.keys())
306
212
  try:
307
213
  # attempt to unify argument types with function template types
308
214
  warp.types.infer_argument_types(args, template_types, arg_names)
309
- return f.func(*args)
215
+ return overload.func(*args)
310
216
  except Exception:
311
217
  continue
312
218
 
313
219
  raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
314
220
 
315
- else:
316
- # user-defined function with no overloads
317
-
318
- if self.func is None:
319
- raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
221
+ # user-defined function with no overloads
222
+ if self.func is None:
223
+ raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
320
224
 
321
- # this function has no overloads, call it like a plain Python function
322
- return self.func(*args, **kwargs)
225
+ # this function has no overloads, call it like a plain Python function
226
+ return self.func(*args, **kwargs)
323
227
 
324
228
  def is_builtin(self):
325
229
  return self.func is None
@@ -436,6 +340,184 @@ class Function:
436
340
  return f"<Function {self.key}({inputs_str})>"
437
341
 
438
342
 
343
+ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
344
+ uses_non_warp_array_type = False
345
+
346
+ # Retrieve the built-in function from Warp's dll.
347
+ c_func = getattr(warp.context.runtime.core, func.mangled_name)
348
+
349
+ # Try gathering the parameters that the function expects and pack them
350
+ # into their corresponding C types.
351
+ c_params = []
352
+ for i, (_, arg_type) in enumerate(func.input_types.items()):
353
+ param = params[i]
354
+
355
+ try:
356
+ iter(param)
357
+ except TypeError:
358
+ is_array = False
359
+ else:
360
+ is_array = True
361
+
362
+ if is_array:
363
+ if not issubclass(arg_type, ctypes.Array):
364
+ return (False, None)
365
+
366
+ # The argument expects a built-in Warp type like a vector or a matrix.
367
+
368
+ c_param = None
369
+
370
+ if isinstance(param, ctypes.Array):
371
+ # The given parameter is also a built-in Warp type, so we only need
372
+ # to make sure that it matches with the argument.
373
+ if not warp.types.types_equal(type(param), arg_type):
374
+ return (False, None)
375
+
376
+ if isinstance(param, arg_type):
377
+ c_param = param
378
+ else:
379
+ # Cast the value to its argument type to make sure that it
380
+ # can be assigned to the field of the `Param` struct.
381
+ # This could error otherwise when, for example, the field type
382
+ # is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
383
+ # even though both types are semantically identical.
384
+ c_param = arg_type(param)
385
+ else:
386
+ # Flatten the parameter values into a flat 1-D array.
387
+ arr = []
388
+ ndim = 1
389
+ stack = [(0, param)]
390
+ while stack:
391
+ depth, elem = stack.pop(0)
392
+ try:
393
+ # If `elem` is a sequence, then it should be possible
394
+ # to add its elements to the stack for later processing.
395
+ stack.extend((depth + 1, x) for x in elem)
396
+ except TypeError:
397
+ # Since `elem` doesn't seem to be a sequence,
398
+ # we must have a leaf value that we need to add to our
399
+ # resulting array.
400
+ arr.append(elem)
401
+ ndim = max(depth, ndim)
402
+
403
+ assert ndim > 0
404
+
405
+ # Ensure that if the given parameter value is, say, a 2-D array,
406
+ # then we try to resolve it against a matrix argument rather than
407
+ # a vector.
408
+ if ndim > len(arg_type._shape_):
409
+ return (False, None)
410
+
411
+ elem_count = len(arr)
412
+ if elem_count != arg_type._length_:
413
+ return (False, None)
414
+
415
+ # Retrieve the element type of the sequence while ensuring
416
+ # that it's homogeneous.
417
+ elem_type = type(arr[0])
418
+ for i in range(1, elem_count):
419
+ if type(arr[i]) is not elem_type:
420
+ raise ValueError("All array elements must share the same type.")
421
+
422
+ expected_elem_type = arg_type._wp_scalar_type_
423
+ if not (
424
+ elem_type is expected_elem_type
425
+ or (elem_type is float and expected_elem_type is warp.types.float32)
426
+ or (elem_type is int and expected_elem_type is warp.types.int32)
427
+ or (
428
+ issubclass(elem_type, np.number)
429
+ and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
430
+ )
431
+ ):
432
+ # The parameter value has a type not matching the type defined
433
+ # for the corresponding argument.
434
+ return (False, None)
435
+
436
+ if elem_type in warp.types.int_types:
437
+ # Pass the value through the expected integer type
438
+ # in order to evaluate any integer wrapping.
439
+ # For example `uint8(-1)` should result in the value `255`.
440
+ arr = tuple(elem_type._type_(x.value).value for x in arr)
441
+ elif elem_type in warp.types.float_types:
442
+ # Extract the floating-point values.
443
+ arr = tuple(x.value for x in arr)
444
+
445
+ c_param = arg_type()
446
+ if warp.types.type_is_matrix(arg_type):
447
+ rows, cols = arg_type._shape_
448
+ for i in range(rows):
449
+ idx_start = i * cols
450
+ idx_end = idx_start + cols
451
+ c_param[i] = arr[idx_start:idx_end]
452
+ else:
453
+ c_param[:] = arr
454
+
455
+ uses_non_warp_array_type = True
456
+
457
+ c_params.append(ctypes.byref(c_param))
458
+ else:
459
+ if issubclass(arg_type, ctypes.Array):
460
+ return (False, None)
461
+
462
+ if not (
463
+ isinstance(param, arg_type)
464
+ or (type(param) is float and arg_type is warp.types.float32)
465
+ or (type(param) is int and arg_type is warp.types.int32)
466
+ or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
467
+ ):
468
+ return (False, None)
469
+
470
+ if type(param) in warp.types.scalar_types:
471
+ param = param.value
472
+
473
+ # try to pack as a scalar type
474
+ if arg_type == warp.types.float16:
475
+ c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
476
+ else:
477
+ c_params.append(arg_type._type_(param))
478
+
479
+ # returns the corresponding ctype for a scalar or vector warp type
480
+ value_type = func.value_func(None, None, None)
481
+ if value_type == float:
482
+ value_ctype = ctypes.c_float
483
+ elif value_type == int:
484
+ value_ctype = ctypes.c_int32
485
+ elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
486
+ value_ctype = value_type
487
+ else:
488
+ # scalar type
489
+ value_ctype = value_type._type_
490
+
491
+ # construct return value (passed by address)
492
+ ret = value_ctype()
493
+ ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
494
+ c_params.append(ret_addr)
495
+
496
+ # Call the built-in function from Warp's dll.
497
+ c_func(*c_params)
498
+
499
+ # TODO: uncomment when we have a way to print warning messages only once.
500
+ # if uses_non_warp_array_type:
501
+ # warp.utils.warn(
502
+ # "Support for built-in functions called with non-Warp array types, "
503
+ # "such as lists, tuples, NumPy arrays, and others, will be dropped "
504
+ # "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
505
+ # "`wp.quat`, or `wp.transform`.",
506
+ # DeprecationWarning,
507
+ # stacklevel=3
508
+ # )
509
+
510
+ if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
511
+ # return vector types as ctypes
512
+ return (True, ret)
513
+
514
+ if value_type == warp.types.float16:
515
+ return (True, warp.types.half_bits_to_float(ret.value))
516
+
517
+ # return scalar types as int/float
518
+ return (True, ret.value)
519
+
520
+
439
521
  class KernelHooks:
440
522
  def __init__(self, forward, backward):
441
523
  self.forward = forward
@@ -852,6 +934,7 @@ def add_builtin(
852
934
  missing_grad=False,
853
935
  native_func=None,
854
936
  defaults=None,
937
+ require_original_output_arg=False,
855
938
  ):
856
939
  # wrap simple single-type functions with a value_func()
857
940
  if value_func is None:
@@ -976,6 +1059,7 @@ def add_builtin(
976
1059
  hidden=True,
977
1060
  skip_replay=skip_replay,
978
1061
  missing_grad=missing_grad,
1062
+ require_original_output_arg=require_original_output_arg,
979
1063
  )
980
1064
 
981
1065
  func = Function(
@@ -996,6 +1080,7 @@ def add_builtin(
996
1080
  generic=generic,
997
1081
  native_func=native_func,
998
1082
  defaults=defaults,
1083
+ require_original_output_arg=require_original_output_arg,
999
1084
  )
1000
1085
 
1001
1086
  if key in builtin_functions:
@@ -1005,7 +1090,7 @@ def add_builtin(
1005
1090
 
1006
1091
  # export means the function will be added to the `warp` module namespace
1007
1092
  # so that users can call it directly from the Python interpreter
1008
- if export is True:
1093
+ if export:
1009
1094
  if hasattr(warp, key):
1010
1095
  # check that we haven't already created something at this location
1011
1096
  # if it's just an overload stub for auto-complete then overwrite it
@@ -1355,7 +1440,7 @@ class Module:
1355
1440
  ch.update(bytes(s, "utf-8"))
1356
1441
  if func.custom_replay_func:
1357
1442
  s = func.custom_replay_func.adj.source
1358
-
1443
+
1359
1444
  # cache func arg types
1360
1445
  for arg, arg_type in func.adj.arg_types.items():
1361
1446
  s = f"{arg}: {get_type_name(arg_type)}"
@@ -3409,7 +3494,7 @@ def launch(
3409
3494
  device = runtime.get_device(device)
3410
3495
 
3411
3496
  # check function is a Kernel
3412
- if isinstance(kernel, Kernel) is False:
3497
+ if not isinstance(kernel, Kernel):
3413
3498
  raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
3414
3499
 
3415
3500
  # debugging aid
@@ -3693,7 +3778,7 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
3693
3778
  return get_module(m.__name__).options
3694
3779
 
3695
3780
 
3696
- def capture_begin(device: Devicelike = None, stream=None, force_module_load=True):
3781
+ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
3697
3782
  """Begin capture of a CUDA graph
3698
3783
 
3699
3784
  Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -3707,7 +3792,10 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True
3707
3792
 
3708
3793
  """
3709
3794
 
3710
- if warp.config.verify_cuda is True:
3795
+ if force_module_load is None:
3796
+ force_module_load = warp.config.graph_capture_module_load_default
3797
+
3798
+ if warp.config.verify_cuda:
3711
3799
  raise RuntimeError("Cannot use CUDA error verification during graph capture")
3712
3800
 
3713
3801
  if stream is not None:
@@ -3990,7 +4078,7 @@ def print_function(f, file, noentry=False): # pragma: no cover
3990
4078
  return True
3991
4079
 
3992
4080
 
3993
- def print_builtins(file): # pragma: no cover
4081
+ def export_functions_rst(file): # pragma: no cover
3994
4082
  header = (
3995
4083
  "..\n"
3996
4084
  " Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
@@ -4031,6 +4119,14 @@ def print_builtins(file): # pragma: no cover
4031
4119
  print(".. class:: Transformation", file=file)
4032
4120
  print(".. class:: Array", file=file)
4033
4121
 
4122
+ print("\nQuery Types", file=file)
4123
+ print("-------------", file=file)
4124
+ print(".. autoclass:: bvh_query_t", file=file)
4125
+ print(".. autoclass:: hash_grid_query_t", file=file)
4126
+ print(".. autoclass:: mesh_query_aabb_t", file=file)
4127
+ print(".. autoclass:: mesh_query_point_t", file=file)
4128
+ print(".. autoclass:: mesh_query_ray_t", file=file)
4129
+
4034
4130
  # build dictionary of all functions by group
4035
4131
  groups = {}
4036
4132
 
@@ -4114,7 +4210,7 @@ def export_stubs(file): # pragma: no cover
4114
4210
 
4115
4211
  return_str = ""
4116
4212
 
4117
- if f.export is False or f.hidden is True: # or f.generic:
4213
+ if not f.export or f.hidden: # or f.generic:
4118
4214
  continue
4119
4215
 
4120
4216
  try:
@@ -4136,7 +4232,17 @@ def export_stubs(file): # pragma: no cover
4136
4232
 
4137
4233
 
4138
4234
  def export_builtins(file: io.TextIOBase): # pragma: no cover
4139
- def ctype_str(t):
4235
+ def ctype_arg_str(t):
4236
+ if isinstance(t, int):
4237
+ return "int"
4238
+ elif isinstance(t, float):
4239
+ return "float"
4240
+ elif t in warp.types.vector_types:
4241
+ return f"{t.__name__}&"
4242
+ else:
4243
+ return t.__name__
4244
+
4245
+ def ctype_ret_str(t):
4140
4246
  if isinstance(t, int):
4141
4247
  return "int"
4142
4248
  elif isinstance(t, float):
@@ -4149,7 +4255,7 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
4149
4255
 
4150
4256
  for k, g in builtin_functions.items():
4151
4257
  for f in g.overloads:
4152
- if f.export is False or f.generic:
4258
+ if not f.export or f.generic:
4153
4259
  continue
4154
4260
 
4155
4261
  simple = True
@@ -4163,7 +4269,7 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
4163
4269
  if not simple or f.variadic:
4164
4270
  continue
4165
4271
 
4166
- args = ", ".join(f"{ctype_str(v)} {k}" for k, v in f.input_types.items())
4272
+ args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
4167
4273
  params = ", ".join(f.input_types.keys())
4168
4274
 
4169
4275
  return_type = ""
@@ -4171,7 +4277,7 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
4171
4277
  try:
4172
4278
  # todo: construct a default value for each of the functions args
4173
4279
  # so we can generate the return type for overloaded functions
4174
- return_type = ctype_str(f.value_func(None, None, None))
4280
+ return_type = ctype_ret_str(f.value_func(None, None, None))
4175
4281
  except Exception:
4176
4282
  continue
4177
4283
 
warp/fem/__init__.py CHANGED
@@ -2,12 +2,12 @@ from .geometry import Geometry, Grid2D, Trimesh2D, Quadmesh2D, Grid3D, Tetmesh,
2
2
  from .geometry import GeometryPartition, LinearGeometryPartition, ExplicitGeometryPartition
3
3
 
4
4
  from .space import FunctionSpace, make_polynomial_space, ElementBasis
5
- from .space import BasisSpace, make_polynomial_basis_space, make_collocated_function_space
5
+ from .space import BasisSpace, PointBasisSpace, make_polynomial_basis_space, make_collocated_function_space
6
6
  from .space import DofMapper, SkewSymmetricTensorMapper, SymmetricTensorMapper
7
7
  from .space import SpaceTopology, SpacePartition, SpaceRestriction, make_space_partition, make_space_restriction
8
8
 
9
9
  from .domain import GeometryDomain, Cells, Sides, BoundarySides, FrontierSides
10
- from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, PicQuadrature
10
+ from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, ExplicitQuadrature, PicQuadrature
11
11
  from .polynomial import Polynomial
12
12
 
13
13
  from .field import FieldLike, DiscreteField, make_test, make_trial, make_restriction
warp/fem/cache.py CHANGED
@@ -95,6 +95,7 @@ def dynamic_struct(suffix: str, use_qualified_name=False):
95
95
  def get_integrand_function(
96
96
  integrand: "warp.fem.operator.Integrand",
97
97
  suffix: str,
98
+ func=None,
98
99
  annotations=None,
99
100
  code_transformers=[],
100
101
  ):
@@ -102,7 +103,7 @@ def get_integrand_function(
102
103
 
103
104
  if key not in _func_cache:
104
105
  _func_cache[key] = wp.Function(
105
- func=integrand.func,
106
+ func=integrand.func if func is None else func,
106
107
  key=key,
107
108
  namespace="",
108
109
  module=integrand.module,
warp/fem/field/nodal_field.py CHANGED
@@ -84,15 +84,14 @@ class NodalFieldBase(DiscreteField):
84
84
  if not self.gradient_valid():
85
85
  return None
86
86
 
87
- @cache.dynamic_func(suffix=self.name + ("W" if world_space else "R"))
88
- def eval_grad_inner(args: self.ElementEvalArg, s: Sample):
87
+ @cache.dynamic_func(suffix=self.name)
88
+ def eval_grad_inner_ref_space(args: self.ElementEvalArg, s: Sample):
89
89
  res = utils.generalized_outer(
90
90
  self._read_node_value(args, s.element_index, 0),
91
91
  self.space.element_inner_weight_gradient(
92
92
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, 0
93
93
  ),
94
94
  )
95
-
96
95
  for k in range(1, NODES_PER_ELEMENT):
97
96
  res += utils.generalized_outer(
98
97
  self._read_node_value(args, s.element_index, k),
@@ -100,14 +99,15 @@ class NodalFieldBase(DiscreteField):
100
99
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
101
100
  ),
102
101
  )
103
-
104
- if world_space:
105
- grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
106
- return utils.apply_right(res, grad_transform)
107
-
108
102
  return res
109
103
 
110
- return eval_grad_inner
104
+ @cache.dynamic_func(suffix=self.name)
105
+ def eval_grad_inner_world_space(args: self.ElementEvalArg, s: Sample):
106
+ grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
107
+ res = eval_grad_inner_ref_space(args, s)
108
+ return utils.apply_right(res, grad_transform)
109
+
110
+ return eval_grad_inner_world_space if world_space else eval_grad_inner_ref_space
111
111
 
112
112
  def _make_eval_div_inner(self):
113
113
  NODES_PER_ELEMENT = self.space.topology.NODES_PER_ELEMENT
@@ -173,8 +173,8 @@ class NodalFieldBase(DiscreteField):
173
173
  if not self.gradient_valid():
174
174
  return None
175
175
 
176
- @cache.dynamic_func(suffix=self.name + ("W" if world_space else "R"))
177
- def eval_grad_outer(args: self.ElementEvalArg, s: Sample):
176
+ @cache.dynamic_func(suffix=self.name)
177
+ def eval_grad_outer_ref_space(args: self.ElementEvalArg, s: Sample):
178
178
  res = utils.generalized_outer(
179
179
  self._read_node_value(args, s.element_index, 0),
180
180
  self.space.element_outer_weight_gradient(
@@ -188,14 +188,15 @@ class NodalFieldBase(DiscreteField):
188
188
  args.elt_arg, args.eval_arg.space_arg, s.element_index, s.element_coords, k
189
189
  ),
190
190
  )
191
-
192
- if world_space:
193
- grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
194
- return utils.apply_right(res, grad_transform)
195
-
196
191
  return res
197
192
 
198
- return eval_grad_outer
193
+ @cache.dynamic_func(suffix=self.name)
194
+ def eval_grad_outer_world_space(args: self.ElementEvalArg, s: Sample):
195
+ grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
196
+ res = eval_grad_outer_ref_space(args, s)
197
+ return utils.apply_right(res, grad_transform)
198
+
199
+ return eval_grad_outer_world_space if world_space else eval_grad_outer_ref_space
199
200
 
200
201
  def _make_eval_div_outer(self):
201
202
  NODES_PER_ELEMENT = self.space.topology.NODES_PER_ELEMENT
warp/fem/geometry/hexmesh.py CHANGED
@@ -1,11 +1,16 @@
1
1
  from typing import Optional
2
- import warp as wp
3
2
 
4
- from warp.fem.types import ElementIndex, Coords, Sample, OUTSIDE, make_free_sample
5
- from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary, borrow_temporary_like
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
6
11
 
12
+ from .element import Cube, Square
7
13
  from .geometry import Geometry
8
- from .element import Square, Cube
9
14
 
10
15
 
11
16
  @wp.struct
@@ -493,7 +498,7 @@ class Hexmesh(Geometry):
493
498
  wp.copy(
494
499
  dest=face_count.array, src=vertex_unique_face_offsets.array, src_offset=self.vertex_count() - 1, count=1
495
500
  )
496
- wp.synchronize_stream(wp.get_stream())
501
+ wp.synchronize_stream(wp.get_stream(device))
497
502
  face_count = int(face_count.array.numpy()[0])
498
503
  else:
499
504
  face_count = int(vertex_unique_face_offsets.array.numpy()[self.vertex_count() - 1])
@@ -603,7 +608,7 @@ class Hexmesh(Geometry):
603
608
  src_offset=self.vertex_count() - 1,
604
609
  count=1,
605
610
  )
606
- wp.synchronize_stream(wp.get_stream())
611
+ wp.synchronize_stream(wp.get_stream(device))
607
612
  self._edge_count = int(edge_count.array.numpy()[0])
608
613
  else:
609
614
  self._edge_count = int(vertex_unique_edge_offsets.array.numpy()[self.vertex_count() - 1])