warp-lang 1.2.2__py3-none-manylinux2014_x86_64.whl → 1.3.1__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (193)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1412 -888
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +91 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.1.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -18,6 +18,7 @@ import os
  import platform
  import sys
  import types
+ import typing
  from copy import copy as shallowcopy
  from pathlib import Path
  from struct import pack as struct_pack
@@ -34,7 +35,7 @@ import warp.config
 
 
  def create_value_func(type):
- def value_func(args, kwds, templates):
+ def value_func(arg_types, arg_values):
  return type
 
  return value_func
@@ -42,7 +43,7 @@ def create_value_func(type):
 
  def get_function_args(func):
  """Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type."""
- argspec = inspect.getfullargspec(func)
+ argspec = warp.codegen.get_full_arg_spec(func)
 
  # use source-level argument annotations
  if len(argspec.annotations) < len(argspec.args):
@@ -63,7 +64,8 @@ class Function:
  input_types=None,
  value_type=None,
  value_func=None,
- template_func=None,
+ export_func=None,
+ dispatch_func=None,
  module=None,
  variadic=False,
  initializer_list_func=None,
@@ -97,14 +99,15 @@ class Function:
  self.namespace = namespace
  self.value_type = value_type
  self.value_func = value_func # a function that takes a list of args and a list of templates and returns the value type, e.g.: load(array, index) returns the type of value being loaded
- self.template_func = template_func
+ self.export_func = export_func
+ self.dispatch_func = dispatch_func
  self.input_types = {}
  self.export = export
  self.doc = doc
  self.group = group
  self.module = module
  self.variadic = variadic # function can take arbitrary number of inputs, e.g.: printf()
- self.defaults = defaults
+ self.defaults = {} if defaults is None else defaults
  # Function instance for a custom implementation of the replay pass
  self.custom_replay_func = custom_replay_func
  self.native_snippet = native_snippet
@@ -180,6 +183,33 @@ class Function:
  if not skip_adding_overload:
  self.add_overload(self)
 
+ # Store a description of the function's signature that can be used
+ # to resolve a bunch of positional/keyword/variadic arguments against,
+ # in a way that is compatible with Python's semantics.
+ signature_params = []
+ signature_default_param_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
+ for param_name in self.input_types.keys():
+ if param_name.startswith("**"):
+ param_name = param_name[2:]
+ param_kind = inspect.Parameter.VAR_KEYWORD
+ elif param_name.startswith("*"):
+ param_name = param_name[1:]
+ param_kind = inspect.Parameter.VAR_POSITIONAL
+
+ # Once a variadic argument like `*args` is found, any following
+ # arguments need to be passed using keywords.
+ signature_default_param_kind = inspect.Parameter.KEYWORD_ONLY
+ else:
+ param_kind = signature_default_param_kind
+
+ param = param = inspect.Parameter(
+ param_name,
+ param_kind,
+ default=self.defaults.get(param_name, inspect.Parameter.empty),
+ )
+ signature_params.append(param)
+ self.signature = inspect.Signature(signature_params)
+
  # add to current module
  if module:
  module.register_function(self, skip_adding_overload)
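
Note: the `inspect.Signature` built above is what lets Warp resolve positional, keyword, and variadic call arguments with regular Python semantics before picking an overload. A minimal standalone sketch of the idea (hypothetical parameter names, not code from the package):

    import inspect

    # model a two-parameter built-in where the second parameter has a default
    params = [
        inspect.Parameter("seed", inspect.Parameter.POSITIONAL_OR_KEYWORD),
        inspect.Parameter("offset", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=0),
    ]
    sig = inspect.Signature(params)

    bound = sig.bind(42)      # resolve a purely positional call
    bound.apply_defaults()    # fill in offset=0 from the declared default
    print(bound.arguments)    # {'seed': 42, 'offset': 0} (mapping of resolved arguments)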
@@ -247,7 +277,7 @@ class Function:
 
  # only export simple types that don't use arrays
  for v in self.input_types.values():
- if isinstance(v, warp.array) or v in complex_type_hints:
+ if warp.types.is_array(v) or v in complex_type_hints:
  return False
 
  if type(self.value_type) in sequence_types:
@@ -261,8 +291,14 @@ class Function:
 
  name = "builtin_" + self.key
 
+ # Runtime arguments that are to be passed to the function, not its template signature.
+ if self.export_func is not None:
+ func_args = self.export_func(self.input_types)
+ else:
+ func_args = self.input_types
+
  types = []
- for t in self.input_types.values():
+ for t in func_args.values():
  types.append(t.__name__)
 
  return "_".join([name, *types])
@@ -299,7 +335,7 @@ class Function:
  )
  self.user_overloads[sig] = f
 
- def get_overload(self, arg_types):
+ def get_overload(self, arg_types, kwarg_types):
  assert not self.is_builtin()
 
  sig = warp.types.get_signature(arg_types, func_name=self.key)
@@ -347,15 +383,21 @@
  def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
  uses_non_warp_array_type = False
 
- warp.context.init()
+ init()
 
  # Retrieve the built-in function from Warp's dll.
  c_func = getattr(warp.context.runtime.core, func.mangled_name)
 
+ # Runtime arguments that are to be passed to the function, not its template signature.
+ if func.export_func is not None:
+ func_args = func.export_func(func.input_types)
+ else:
+ func_args = func.input_types
+
  # Try gathering the parameters that the function expects and pack them
  # into their corresponding C types.
  c_params = []
- for i, (_, arg_type) in enumerate(func.input_types.items()):
+ for i, (_, arg_type) in enumerate(func_args.items()):
  param = params[i]
 
  try:
@@ -485,7 +527,8 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
  c_params.append(arg_type._type_(param))
 
  # returns the corresponding ctype for a scalar or vector warp type
- value_type = func.value_func(None, None, None)
+ value_type = func.value_func(None, None)
+
  if value_type == float:
  value_ctype = ctypes.c_float
  elif value_type == int:
@@ -521,10 +564,12 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
  return (True, ret)
 
  if value_type == warp.types.float16:
- return (True, warp.types.half_bits_to_float(ret.value))
+ value = warp.types.half_bits_to_float(ret.value)
+ else:
+ value = ret.value
 
  # return scalar types as int/float
- return (True, ret.value)
+ return (True, value)
 
 
  class KernelHooks:
@@ -742,7 +787,6 @@ def func_grad(forward_fn):
  input_types=reverse_args,
  value_func=None,
  module=f.module,
- template_func=f.template_func,
  skip_forward_codegen=True,
  custom_reverse_mode=True,
  custom_reverse_num_input_args=len(f.input_types),
@@ -807,7 +851,7 @@ def func_replay(forward_fn):
  f"Cannot define custom replay definition for {forward_fn.key} since the provided replay function has generic input arguments."
  )
 
- f = forward_fn.get_overload(arg_types)
+ f = forward_fn.get_overload(arg_types, {})
  if f is None:
  inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in args.items()])
  raise RuntimeError(
@@ -819,8 +863,9 @@ def func_replay(forward_fn):
  namespace=f.namespace,
  input_types=f.input_types,
  value_func=f.value_func,
+ export_func=f.export_func,
+ dispatch_func=f.dispatch_func,
  module=f.module,
- template_func=f.template_func,
  skip_reverse_codegen=True,
  skip_adding_overload=True,
  code_transformers=f.adj.transformers,
@@ -920,7 +965,7 @@ def overload(kernel, arg_types=None):
  )
 
  # ensure all arguments are annotated
- argspec = inspect.getfullargspec(fn)
+ argspec = warp.codegen.get_full_arg_spec(fn)
  if len(argspec.annotations) < len(argspec.args):
  raise RuntimeError(f"Incomplete argument annotations on kernel overload {fn.__name__}")
 
@@ -965,7 +1010,8 @@ def add_builtin(
  constraint=None,
  value_type=None,
  value_func=None,
- template_func=None,
+ export_func=None,
+ dispatch_func=None,
  doc="",
  namespace="wp::",
  variadic=False,
@@ -979,18 +1025,66 @@ def add_builtin(
  defaults=None,
  require_original_output_arg=False,
  ):
+ """Main entry point to register a new built-in function.
+
+ Args:
+ key (str): Function name. Multiple overloaded functions can be registered
+ under the same name as long as their signature differ.
+ input_types (Mapping[str, Any]): Signature of the user-facing function.
+ Variadic arguments are supported by prefixing the parameter names
+ with asterisks as in `*args` and `**kwargs`. Generic arguments are
+ supported with types such as `Any`, `Float`, `Scalar`, etc.
+ constraint (Callable): For functions that define generic arguments and
+ are to be exported, this callback is used to specify whether some
+ combination of inferred arguments are valid or not.
+ value_type (Any): Type returned by the function.
+ value_func (Callable): Callback used to specify the return type when
+ `value_type` isn't enough.
+ export_func (Callable): Callback used during the context stage to specify
+ the signature of the underlying C++ function, not accounting for
+ the template parameters.
+ If not provided, `input_types` is used.
+ dispatch_func (Callable): Callback used during the codegen stage to specify
+ the runtime and template arguments to be passed to the underlying C++
+ function. In other words, this allows defining a mapping between
+ the signatures of the user-facing and the C++ functions, and even to
+ dynamically create new arguments on the fly.
+ The arguments returned must be of type `codegen.Var`.
+ If not provided, all arguments passed by the users when calling
+ the built-in are passed as-is as runtime arguments to the C++ function.
+ doc (str): Used to generate the Python's docstring and the HTML documentation.
+ namespace: Namespace for the underlying C++ function.
+ variadic (bool): Whether the function declares variadic arguments.
+ initializer_list_func (bool): Whether to use the initializer list syntax
+ when passing the arguments to the underlying C++ function.
+ export (bool): Whether the function is to be exposed to the Python
+ interpreter so that it becomes available from within the `warp`
+ module.
+ group (str): Classification used for the documentation.
+ hidden (bool): Whether to add that function into the documentation.
+ skip_replay (bool): Whether operation will be performed during
+ the forward replay in the backward pass.
+ missing_grad (bool): Whether the function is missing a corresponding
+ adjoint.
+ native_func (str): Name of the underlying C++ function.
+ defaults (Mapping[str, Any]): Default values for the parameters defined
+ in `input_types`.
+ require_original_output_arg (bool): Used during the codegen stage to
+ specify whether an adjoint parameter corresponding to the return
+ value should be included in the signature of the backward function.
+ """
  if input_types is None:
  input_types = {}
 
  # wrap simple single-type functions with a value_func()
  if value_func is None:
 
- def value_func(args, kwds, templates):
+ def value_func(arg_types, arg_values):
  return value_type
 
  if initializer_list_func is None:
 
- def initializer_list_func(args, templates):
+ def initializer_list_func(args, return_type):
  return False
 
  if defaults is None:
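
For illustration only, a hedged sketch of what a callback using the new two-argument `value_func(arg_types, arg_values)` form might look like; the function and parameter names here are hypothetical and not taken from the package:

    # `arg_types` maps parameter names to the resolved Warp types (or is None when
    # no call context is available, e.g. while generating docs/stubs), and
    # `arg_values` may carry known argument values for the call being resolved.
    def my_value_func(arg_types, arg_values):
        if arg_types is None:
            return None  # no context: fall back to an unspecified return type
        # e.g. return the dtype of a hypothetical array parameter named "a"
        return arg_types["a"].dtype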
@@ -998,8 +1092,13 @@ def add_builtin(
 
  # Add specialized versions of this builtin if it's generic by matching arguments against
  # hard coded types. We do this so you can use hard coded warp types outside kernels:
+ if export_func is not None:
+ func_arg_types = export_func(input_types)
+ else:
+ func_arg_types = input_types
+
  generic = False
- for x in input_types.values():
+ for x in func_arg_types.values():
  if warp.types.type_is_generic(x):
  generic = True
  break
@@ -1007,7 +1106,7 @@ def add_builtin(
  if generic and export:
  # collect the parent type names of all the generic arguments:
  genericset = set()
- for t in input_types.values():
+ for t in func_arg_types.values():
  if hasattr(t, "_wp_generic_type_hint_"):
  genericset.add(t._wp_generic_type_hint_)
  elif warp.types.type_is_generic_scalar(t):
@@ -1059,15 +1158,17 @@ def add_builtin(
 
  typelists.append(l)
 
- for argtypes in itertools.product(*typelists):
+ for arg_types in itertools.product(*typelists):
+ arg_types = dict(zip(input_types.keys(), arg_types))
+
  # Some of these argument lists won't work, eg if the function is mul(), we won't be
  # able to do a matrix vector multiplication for a mat22 and a vec3. The `constraint`
  # function determines which combinations are valid:
  if constraint:
- if constraint(argtypes) is False:
+ if constraint(arg_types) is False:
  continue
 
- return_type = value_func(argtypes, {}, [])
+ return_type = value_func(arg_types, None)
 
  # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
  # in the list of hard coded types so it knows it's returning one of them:
@@ -1085,8 +1186,10 @@ def add_builtin(
  # finally we can generate a function call for these concrete types:
  add_builtin(
  key,
- input_types=dict(zip(input_types.keys(), argtypes)),
+ input_types=arg_types,
  value_type=return_type,
+ export_func=export_func,
+ dispatch_func=dispatch_func,
  doc=doc,
  namespace=namespace,
  variadic=variadic,
@@ -1096,6 +1199,7 @@ def add_builtin(
  hidden=True,
  skip_replay=skip_replay,
  missing_grad=missing_grad,
+ defaults=defaults,
  require_original_output_arg=require_original_output_arg,
  )
 
@@ -1106,7 +1210,8 @@ def add_builtin(
  input_types=input_types,
  value_type=value_type,
  value_func=value_func,
- template_func=template_func,
+ export_func=export_func,
+ dispatch_func=dispatch_func,
  variadic=variadic,
  initializer_list_func=initializer_list_func,
  export=export,
@@ -1250,7 +1355,7 @@ class ModuleBuilder:
  if not func.value_func:
 
  def wrap(adj):
- def value_type(arg_types, kwds, templates):
+ def value_type(arg_types, arg_values):
  if adj.return_var is None or len(adj.return_var) == 0:
  return None
  if len(adj.return_var) == 1:
@@ -1453,14 +1558,6 @@ class Module:
  computed ``content_hash`` will be used.
  """
 
- def get_annotations(obj: Any) -> Mapping[str, Any]:
- """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
- # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
- if isinstance(obj, type):
- return obj.__dict__.get("__annotations__", {})
-
- return getattr(obj, "__annotations__", {})
-
  def get_type_name(type_hint):
  if isinstance(type_hint, warp.codegen.Struct):
  return get_type_name(type_hint.cls)
@@ -1482,7 +1579,7 @@ class Module:
  for struct in module.structs.values():
  s = ",".join(
  "{}: {}".format(name, get_type_name(type_hint))
- for name, type_hint in get_annotations(struct.cls).items()
+ for name, type_hint in warp.codegen.get_annotations(struct.cls).items()
  )
  ch.update(bytes(s, "utf-8"))
 
@@ -1495,22 +1592,18 @@ class Module:
  ch.update(bytes(sig, "utf-8"))
 
  # source
- s = func.adj.source
- ch.update(bytes(s, "utf-8"))
+ ch.update(bytes(func.adj.source, "utf-8"))
 
  if func.custom_grad_func:
- s = func.custom_grad_func.adj.source
- ch.update(bytes(s, "utf-8"))
+ ch.update(bytes(func.custom_grad_func.adj.source, "utf-8"))
  if func.custom_replay_func:
- s = func.custom_replay_func.adj.source
+ ch.update(bytes(func.custom_replay_func.adj.source, "utf-8"))
  if func.replay_snippet:
- s = func.replay_snippet
+ ch.update(bytes(func.replay_snippet, "utf-8"))
  if func.native_snippet:
- s = func.native_snippet
- ch.update(bytes(s, "utf-8"))
+ ch.update(bytes(func.native_snippet, "utf-8"))
  if func.adj_native_snippet:
- s = func.adj_native_snippet
- ch.update(bytes(s, "utf-8"))
+ ch.update(bytes(func.adj_native_snippet, "utf-8"))
 
  # Populate constants referenced in this function
  if func.adj:
@@ -1621,7 +1714,7 @@ class Module:
 
  with ScopedTimer(
  f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
- ):
+ ) as module_load_timer:
  # -----------------------------------------------------------
  # determine output paths
  if device.is_cpu:
@@ -1657,7 +1750,13 @@ class Module:
 
  build_dir = None
 
- if not os.path.exists(binary_path) or not warp.config.cache_kernels:
+ # we always want to build if binary doesn't exist yet
+ # and we want to rebuild if we are not caching kernels or if we are tracking array access
+ if (
+ not os.path.exists(binary_path)
+ or not warp.config.cache_kernels
+ or warp.config.verify_autograd_array_access
+ ):
  builder = ModuleBuilder(self, self.options)
 
  # create a temporary (process unique) dir for build outputs before moving to the binary dir
@@ -1668,6 +1767,8 @@ class Module:
  # dir may exist from previous attempts / runs / archs
  Path(build_dir).mkdir(parents=True, exist_ok=True)
 
+ module_load_timer.extra_msg = " (compiled)" # For wp.ScopedTimer informational purposes
+
  # build CPU
  if device.is_cpu:
  # build
@@ -1694,6 +1795,7 @@ class Module:
 
  except Exception as e:
  self.cpu_build_failed = True
+ module_load_timer.extra_msg = " (error)"
  raise (e)
 
  elif device.is_cuda:
@@ -1722,6 +1824,7 @@ class Module:
 
  except Exception as e:
  self.cuda_build_failed = True
+ module_load_timer.extra_msg = " (error)"
  raise (e)
 
  # -----------------------------------------------------------
@@ -1755,6 +1858,8 @@ class Module:
  except Exception as e:
  # We don't need source_code_path to be copied successfully to proceed, so warn and keep running
  warp.utils.warn(f"Exception when renaming {source_code_path}: {e}")
+ else:
+ module_load_timer.extra_msg = " (cached)" # For wp.ScopedTimer informational purposes
 
  # -----------------------------------------------------------
  # Load CPU or CUDA binary
@@ -1767,6 +1872,7 @@ class Module:
  if cuda_module is not None:
  self.cuda_modules[device.context] = cuda_module
  else:
+ module_load_timer.extra_msg = " (error)"
  raise Exception(f"Failed to load CUDA module '{self.name}'")
 
  if build_dir:
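
The `extra_msg` assignments in the hunks above presumably show up appended to the module-load timing line, which is how the " (compiled)", " (cached)", and " (error)" notes are reported. A small usage sketch, assuming `wp.ScopedTimer` behaves as it is used here:

    import warp as wp

    wp.init()

    with wp.ScopedTimer("load step", active=True) as timer:
        _ = sum(range(1000))           # placeholder for the timed work
        timer.extra_msg = " (cached)"  # informational suffix for the printed timing line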
@@ -1937,10 +2043,13 @@ class ContextGuard:
 
 
  class Stream:
- def __init__(self, device=None, **kwargs):
- self.cuda_stream = None
- self.owner = False
+ def __new__(cls, *args, **kwargs):
+ instance = super(Stream, cls).__new__(cls)
+ instance.cuda_stream = None
+ instance.owner = False
+ return instance
 
+ def __init__(self, device=None, **kwargs):
  # event used internally for synchronization (cached to avoid creating temporary events)
  self._cached_event = None
 
@@ -2016,9 +2125,12 @@ class Event:
  BLOCKING_SYNC = 0x1
  DISABLE_TIMING = 0x2
 
- def __init__(self, device=None, cuda_event=None, enable_timing=False):
- self.owner = False
+ def __new__(cls, *args, **kwargs):
+ instance = super(Event, cls).__new__(cls)
+ instance.owner = False
+ return instance
 
+ def __init__(self, device=None, cuda_event=None, enable_timing=False):
  device = get_device(device)
  if not device.is_cuda:
  raise RuntimeError(f"Device {device} is not a CUDA device")
@@ -2320,6 +2432,11 @@ Devicelike = Union[Device, str, None]
 
 
  class Graph:
+ def __new__(cls, *args, **kwargs):
+ instance = super(Graph, cls).__new__(cls)
+ instance.exec = None
+ return instance
+
  def __init__(self, device: Device, exec: ctypes.c_void_p):
  self.device = device
  self.exec = exec
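
Stream, Event, and Graph (and RegisteredGLBuffer further below) now pre-set in `__new__` the attributes their cleanup code touches; a plausible reading is that this keeps `__del__` from hitting an AttributeError when `__init__` raises partway through. A generic sketch of the pattern (not Warp code):

    class Handle:
        def __new__(cls, *args, **kwargs):
            # set the field the finalizer depends on before __init__ can fail
            instance = super().__new__(cls)
            instance.resource = None
            return instance

        def __init__(self, fail=False):
            if fail:
                raise RuntimeError("initialization failed")
            self.resource = object()  # stand-in for acquiring a native handle

        def __del__(self):
            # safe even if __init__ raised: self.resource is always defined
            if self.resource is not None:
                pass  # release the native resource here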
@@ -2682,48 +2799,38 @@ class Runtime:
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_int,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
+ ctypes.c_float * 9,
+ ctypes.c_float * 3,
  ctypes.c_bool,
+ ctypes.c_float,
  ]
  self.core.volume_f_from_tiles_device.restype = ctypes.c_uint64
  self.core.volume_v_from_tiles_device.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_int,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
+ ctypes.c_float * 9,
+ ctypes.c_float * 3,
  ctypes.c_bool,
+ ctypes.c_float * 3,
  ]
  self.core.volume_v_from_tiles_device.restype = ctypes.c_uint64
  self.core.volume_i_from_tiles_device.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_int,
- ctypes.c_float,
- ctypes.c_int,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
+ ctypes.c_float * 9,
+ ctypes.c_float * 3,
  ctypes.c_bool,
+ ctypes.c_int,
  ]
  self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64
  self.core.volume_index_from_tiles_device.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_int,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
+ ctypes.c_float * 9,
+ ctypes.c_float * 3,
  ctypes.c_bool,
  ]
  self.core.volume_index_from_tiles_device.restype = ctypes.c_uint64
@@ -2731,10 +2838,8 @@ class Runtime:
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_int,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_float,
+ ctypes.c_float * 9,
+ ctypes.c_float * 3,
  ctypes.c_bool,
  ]
  self.core.volume_from_active_voxels_device.restype = ctypes.c_uint64
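
These native volume-construction entry points now receive the voxel transform as a flat 9-element float array plus a 3-element translation rather than individual scalars. With argtypes of `ctypes.c_float * 9` and `ctypes.c_float * 3`, the caller passes fixed-size ctypes arrays; an illustrative construction (the values are made up):

    import ctypes

    transform = (ctypes.c_float * 9)(0.1, 0.0, 0.0,
                                     0.0, 0.1, 0.0,
                                     0.0, 0.0, 0.1)
    translation = (ctypes.c_float * 3)(0.0, 0.0, 0.0)
    # instances like these are what the argtypes declared above expect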
@@ -2780,39 +2885,38 @@ class Runtime:
  self.core.volume_get_blind_data_info.restype = ctypes.c_char_p
 
  bsr_matrix_from_triplets_argtypes = [
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_uint64,
- ctypes.c_uint64,
- ctypes.c_uint64,
- ctypes.c_uint64,
- ctypes.c_uint64,
- ctypes.c_uint64,
+ ctypes.c_int, # rows_per_bock
+ ctypes.c_int, # cols_per_blocks
+ ctypes.c_int, # row_count
+ ctypes.c_int, # tpl_nnz
+ ctypes.POINTER(ctypes.c_int), # tpl_rows
+ ctypes.POINTER(ctypes.c_int), # tpl_cols
+ ctypes.c_void_p, # tpl_values
+ ctypes.c_bool, # prune_numerical_zeros
+ ctypes.POINTER(ctypes.c_int), # bsr_offsets
+ ctypes.POINTER(ctypes.c_int), # bsr_columns
+ ctypes.c_void_p, # bsr_values
+ ctypes.POINTER(ctypes.c_int), # bsr_nnz
+ ctypes.c_void_p, # bsr_nnz_event
  ]
+
  self.core.bsr_matrix_from_triplets_float_host.argtypes = bsr_matrix_from_triplets_argtypes
  self.core.bsr_matrix_from_triplets_double_host.argtypes = bsr_matrix_from_triplets_argtypes
  self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
  self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes
 
- self.core.bsr_matrix_from_triplets_float_host.restype = ctypes.c_int
- self.core.bsr_matrix_from_triplets_double_host.restype = ctypes.c_int
- self.core.bsr_matrix_from_triplets_float_device.restype = ctypes.c_int
- self.core.bsr_matrix_from_triplets_double_device.restype = ctypes.c_int
-
  bsr_transpose_argtypes = [
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_uint64,
- ctypes.c_uint64,
- ctypes.c_uint64,
- ctypes.c_uint64,
- ctypes.c_uint64,
- ctypes.c_uint64,
+ ctypes.c_int, # rows_per_bock
+ ctypes.c_int, # cols_per_blocks
+ ctypes.c_int, # row_count
+ ctypes.c_int, # col count
+ ctypes.c_int, # nnz
+ ctypes.POINTER(ctypes.c_int), # transposed_bsr_offsets
+ ctypes.POINTER(ctypes.c_int), # transposed_bsr_columns
+ ctypes.c_void_p, # bsr_values
+ ctypes.POINTER(ctypes.c_int), # transposed_bsr_offsets
+ ctypes.POINTER(ctypes.c_int), # transposed_bsr_columns
+ ctypes.c_void_p, # transposed_bsr_values
  ]
  self.core.bsr_transpose_float_host.argtypes = bsr_transpose_argtypes
  self.core.bsr_transpose_double_host.argtypes = bsr_transpose_argtypes
@@ -3019,35 +3123,63 @@ class Runtime:
  self.device_map["cpu"] = self.cpu_device
  self.context_map[None] = self.cpu_device
 
- cuda_device_count = self.core.cuda_device_get_count()
+ self.is_cuda_enabled = bool(self.core.is_cuda_enabled())
+ self.is_cuda_compatibility_enabled = bool(self.core.is_cuda_compatibility_enabled())
 
- if cuda_device_count > 0:
+ self.toolkit_version = None # CTK version used to build the core lib
+ self.driver_version = None # installed driver version
+ self.min_driver_version = None # minimum required driver version
+
+ self.cuda_devices = []
+ self.cuda_primary_devices = []
+
+ cuda_device_count = 0
+
+ if self.is_cuda_enabled:
  # get CUDA Toolkit and driver versions
- self.toolkit_version = self.core.cuda_toolkit_version()
- self.driver_version = self.core.cuda_driver_version()
-
- # get all architectures supported by NVRTC
- num_archs = self.core.nvrtc_supported_arch_count()
- if num_archs > 0:
- archs = (ctypes.c_int * num_archs)()
- self.core.nvrtc_supported_archs(archs)
- self.nvrtc_supported_archs = list(archs)
+ toolkit_version = self.core.cuda_toolkit_version()
+ driver_version = self.core.cuda_driver_version()
+
+ # save versions as tuples, e.g., (12, 4)
+ self.toolkit_version = (toolkit_version // 1000, (toolkit_version % 1000) // 10)
+ self.driver_version = (driver_version // 1000, (driver_version % 1000) // 10)
+
+ # determine minimum required driver version
+ if self.is_cuda_compatibility_enabled:
+ # we can rely on minor version compatibility, but 11.4 is the absolute minimum required from the driver
+ if self.toolkit_version[0] > 11:
+ self.min_driver_version = (self.toolkit_version[0], 0)
+ else:
+ self.min_driver_version = (11, 4)
  else:
- self.nvrtc_supported_archs = []
+ # we can't rely on minor version compatibility, so the driver can't be older than the toolkit
+ self.min_driver_version = self.toolkit_version
+
+ # determine if the installed driver is sufficient
+ if self.driver_version >= self.min_driver_version:
+ # get all architectures supported by NVRTC
+ num_archs = self.core.nvrtc_supported_arch_count()
+ if num_archs > 0:
+ archs = (ctypes.c_int * num_archs)()
+ self.core.nvrtc_supported_archs(archs)
+ self.nvrtc_supported_archs = set(archs)
+ else:
+ self.nvrtc_supported_archs = set()
 
- # this is so we can give non-primary contexts a reasonable alias
- # associated with the physical device (e.g., "cuda:0.0", "cuda:0.1")
- self.cuda_custom_context_count = [0] * cuda_device_count
+ # get CUDA device count
+ cuda_device_count = self.core.cuda_device_get_count()
 
- # register primary CUDA devices
- self.cuda_devices = []
- self.cuda_primary_devices = []
- for i in range(cuda_device_count):
- alias = f"cuda:{i}"
- device = Device(self, alias, ordinal=i, is_primary=True)
- self.cuda_devices.append(device)
- self.cuda_primary_devices.append(device)
- self.device_map[alias] = device
+ # register primary CUDA devices
+ for i in range(cuda_device_count):
+ alias = f"cuda:{i}"
+ device = Device(self, alias, ordinal=i, is_primary=True)
+ self.cuda_devices.append(device)
+ self.cuda_primary_devices.append(device)
+ self.device_map[alias] = device
+
+ # count known non-primary contexts on each physical device so we can
+ # give them reasonable aliases (e.g., "cuda:0.0", "cuda:0.1")
+ self.cuda_custom_context_count = [0] * cuda_device_count
 
  # set default device
  if cuda_device_count > 0:
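
The raw integers returned by `cuda_toolkit_version()` / `cuda_driver_version()` encode the version as 1000 * major + 10 * minor, so the conversion above turns e.g. 12040 into (12, 4), and later checks such as `driver_version >= min_driver_version` become plain lexicographic tuple comparisons:

    toolkit_version = 12040  # example raw value for CUDA 12.4
    version_tuple = (toolkit_version // 1000, (toolkit_version % 1000) // 10)
    assert version_tuple == (12, 4)
    assert (12, 4) >= (12, 0)  # tuples compare element by element
    assert (11, 8) < (12, 0)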
@@ -3066,14 +3198,8 @@ class Runtime:
  # initialize kernel cache
  warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
 
- devices_without_uva = []
- devices_without_mempool = []
- for cuda_device in self.cuda_devices:
- if cuda_device.is_primary:
- if not cuda_device.is_uva:
- devices_without_uva.append(cuda_device)
- if not cuda_device.is_mempool_supported:
- devices_without_mempool.append(cuda_device)
+ # global tape
+ self.tape = None
 
  # print device and version information
  if not warp.config.quiet:
@@ -3081,18 +3207,24 @@ class Runtime:
 
  greeting.append(f"Warp {warp.config.version} initialized:")
  if cuda_device_count > 0:
- toolkit_version = (self.toolkit_version // 1000, (self.toolkit_version % 1000) // 10)
- driver_version = (self.driver_version // 1000, (self.driver_version % 1000) // 10)
+ # print CUDA version info
  greeting.append(
- f" CUDA Toolkit {toolkit_version[0]}.{toolkit_version[1]}, Driver {driver_version[0]}.{driver_version[1]}"
+ f" CUDA Toolkit {self.toolkit_version[0]}.{self.toolkit_version[1]}, Driver {self.driver_version[0]}.{self.driver_version[1]}"
  )
  else:
- if self.core.is_cuda_enabled():
- # Warp was compiled with CUDA support, but no devices are available
- greeting.append(" CUDA devices not available")
- else:
+ # briefly explain lack of CUDA devices
+ if not self.is_cuda_enabled:
  # Warp was compiled without CUDA support
- greeting.append(" CUDA support not enabled in this build")
+ greeting.append(" CUDA not enabled in this build")
+ elif self.driver_version < self.min_driver_version:
+ # insufficient CUDA driver version
+ greeting.append(
+ f" CUDA Toolkit {self.toolkit_version[0]}.{self.toolkit_version[1]}, Driver {self.driver_version[0]}.{self.driver_version[1]}"
+ " (insufficient CUDA driver version!)"
+ )
+ else:
+ # CUDA is supported, but no devices are available
+ greeting.append(" CUDA devices not available")
  greeting.append(" Devices:")
  alias_str = f'"{self.cpu_device.alias}"'
  name_str = f'"{self.cpu_device.name}"'
@@ -3151,41 +3283,44 @@ class Runtime:
  print("\n".join(greeting))
 
  if cuda_device_count > 0:
- # warn about possible misconfiguration of the system
+ # ensure initialization did not change the initial context (e.g. querying available memory)
+ self.core.cuda_context_set_current(initial_context)
+
+ # detect possible misconfiguration of the system
+ devices_without_uva = []
+ devices_without_mempool = []
+ for cuda_device in self.cuda_primary_devices:
+ if not cuda_device.is_uva:
+ devices_without_uva.append(cuda_device)
+ if not cuda_device.is_mempool_supported:
+ devices_without_mempool.append(cuda_device)
+
  if devices_without_uva:
  # This should not happen on any system officially supported by Warp. UVA is not available
  # on 32-bit Windows, which we don't support. Nonetheless, we should check and report a
  # warning out of abundance of caution. It may help with debugging a broken VM setup etc.
  warp.utils.warn(
- f"Support for Unified Virtual Addressing (UVA) was not detected on devices {devices_without_uva}."
+ f"\n Support for Unified Virtual Addressing (UVA) was not detected on devices {devices_without_uva}."
  )
  if devices_without_mempool:
  warp.utils.warn(
- f"Support for CUDA memory pools was not detected on devices {devices_without_mempool}. "
- "This prevents memory allocations in CUDA graphs and may result in poor performance. "
- "Is the UVM driver enabled?"
+ f"\n Support for CUDA memory pools was not detected on devices {devices_without_mempool}."
+ "\n This prevents memory allocations in CUDA graphs and may result in poor performance."
+ "\n Is the UVM driver enabled?"
  )
 
- # CUDA compatibility check. This should only affect developer builds done with the
- # --quick flag. The consequences of running with an older driver can be obscure and severe,
- # so make sure we print a very visible warning.
- if self.driver_version < self.toolkit_version and not self.core.is_cuda_compatibility_enabled():
- print(
- "******************************************************************\n"
- "* WARNING: *\n"
- "* Warp was compiled without CUDA compatibility support *\n"
- "* (quick build). The CUDA Toolkit version used to build *\n"
- "* Warp is not fully supported by the current driver. *\n"
- "* Some CUDA functionality may not work correctly! *\n"
- "* Update the driver or rebuild Warp without the --quick flag. *\n"
- "******************************************************************\n"
+ elif self.is_cuda_enabled:
+ # Report a warning about insufficient driver version. The warning should appear even in quiet mode
+ # when the greeting message is suppressed. Also try to provide guidance for resolving the situation.
+ if self.driver_version < self.min_driver_version:
+ msg = []
+ msg.append("\n Insufficient CUDA driver version.")
+ msg.append(
+ f"The minimum required CUDA driver version is {self.min_driver_version[0]}.{self.min_driver_version[1]}, "
+ f"but the installed CUDA driver version is {self.driver_version[0]}.{self.driver_version[1]}."
  )
-
- # ensure initialization did not change the initial context (e.g. querying available memory)
- self.core.cuda_context_set_current(initial_context)
-
- # global tape
- self.tape = None
+ msg.append("Visit https://github.com/NVIDIA/warp/blob/main/README.md#installing for guidance.")
+ warp.utils.warn("\n ".join(msg))
 
  def get_error_string(self):
  return self.core.get_error_string().decode("utf-8")
@@ -3208,17 +3343,20 @@ class Runtime:
  return dll
 
  def get_device(self, ident: Devicelike = None) -> Device:
- if isinstance(ident, Device):
+ # special cases
+ if type(ident) is Device:
  return ident
  elif ident is None:
  return self.default_device
- elif isinstance(ident, str):
- if ident == "cuda":
- return self.get_current_cuda_device()
- else:
- return self.device_map[ident]
- else:
- raise RuntimeError(f"Unable to resolve device from argument of type {type(ident)}")
+
+ # string lookup
+ device = self.device_map.get(ident)
+ if device is not None:
+ return device
+ elif ident == "cuda":
+ return self.get_current_cuda_device()
+
+ raise ValueError(f"Invalid device identifier: {ident}")
 
  def set_default_device(self, ident: Devicelike):
  self.default_device = self.get_device(ident)
@@ -3248,7 +3386,7 @@ class Runtime:
  return self.cuda_devices[0]
  else:
  # CUDA is not available
- if not self.core.is_cuda_enabled():
+ if not self.is_cuda_enabled:
  raise RuntimeError('"cuda" device requested but this build of Warp does not support CUDA')
  else:
  raise RuntimeError('"cuda" device requested but CUDA is not supported by the hardware or driver')
@@ -3821,6 +3959,11 @@ class RegisteredGLBuffer:
 
  __fallback_warning_shown = False
 
+ def __new__(cls, *args, **kwargs):
+ instance = super(RegisteredGLBuffer, cls).__new__(cls)
+ instance.resource = None
+ return instance
+
  def __init__(self, gl_buffer_id: int, device: Devicelike = None, flags: int = NONE, fallback_to_copy: bool = True):
  """
  Args:
@@ -4230,6 +4373,10 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
  # allow for NULL arrays
  return arg_type.__ctype__()
 
+ elif isinstance(value, warp.types.array_t):
+ # accept array descriptors verbatum
+ return value
+
  else:
  # check for array type
  # - in forward passes, array types have to match
@@ -4240,6 +4387,32 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
  array_matches = type(value) is type(arg_type)
 
  if not array_matches:
+ # if a regular Warp array is required, try converting from __cuda_array_interface__ or __array_interface__
+ if isinstance(arg_type, warp.array):
+ if device.is_cuda:
+ # check for __cuda_array_interface__
+ try:
+ interface = value.__cuda_array_interface__
+ except AttributeError:
+ pass
+ else:
+ return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)
+ else:
+ # check for __array_interface__
+ try:
+ interface = value.__array_interface__
+ except AttributeError:
+ pass
+ else:
+ return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)
+ # check for __array__() method, e.g. Torch CPU tensors
+ try:
+ interface = value.__array__().__array_interface__
+ except AttributeError:
+ pass
+ else:
+ return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)
+
  adj = "adjoint " if adjoint else ""
  raise RuntimeError(
  f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array of type {type(arg_type)}, but passed value has type {type(value)}."
@@ -4603,6 +4776,10 @@
  caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
  runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device, metadata={"caller": caller})
 
+ # detect illegal inter-kernel read/write access patterns if verification flag is set
+ if warp.config.verify_autograd_array_access:
+ runtime.tape._check_kernel_array_access(kernel, fwd_args)
+
 
  def synchronize():
  """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices
@@ -4808,7 +4985,7 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
  """
 
  if force_module_load is None:
- if runtime.driver_version >= 12030:
+ if runtime.driver_version >= (12, 3):
  # Driver versions 12.3 and can compile modules during graph capture
  force_module_load = False
  else:
@@ -5084,6 +5261,9 @@ def copy(
  ),
  arrays=[dest, src],
  )
+ if warp.config.verify_autograd_array_access:
+ dest.mark_write()
+ src.mark_read()
 
 
  def adj_copy(
@@ -5106,8 +5286,16 @@ def type_str(t):
  return "Any"
  elif t == Callable:
  return "Callable"
+ elif t == Tuple[int]:
+ return "Tuple[int]"
  elif t == Tuple[int, int]:
  return "Tuple[int, int]"
+ elif t == Tuple[int, int, int]:
+ return "Tuple[int, int, int]"
+ elif t == Tuple[int, int, int, int]:
+ return "Tuple[int, int, int, int]"
+ elif t == Tuple[int, ...]:
+ return "Tuple[int, ...]"
  elif isinstance(t, int):
  return str(t)
  elif isinstance(t, List):
@@ -5142,6 +5330,9 @@ def type_str(t):
  return f"Transformation[{type_str(t._wp_scalar_type_)}]"
 
  raise TypeError("Invalid vector or matrix dimensions")
+ elif typing.get_origin(t) in (List, Mapping, Sequence, Union, Tuple):
+ args_repr = ", ".join(type_str(x) for x in typing.get_args(t))
+ return f"{t.__name__}[{args_repr}]"
 
  return t.__name__
 
@@ -5169,7 +5360,7 @@ def print_function(f, file, noentry=False): # pragma: no cover
  try:
  # todo: construct a default value for each of the functions args
  # so we can generate the return type for overloaded functions
- return_type = " -> " + type_str(f.value_func(None, None, None))
+ return_type = " -> " + type_str(f.value_func(None, None))
  except Exception:
  pass
 
@@ -5232,14 +5423,6 @@ def export_functions_rst(file): # pragma: no cover
  print(".. class:: Transformation", file=file)
  print(".. class:: Array", file=file)
 
- print("\nQuery Types", file=file)
- print("-------------", file=file)
- print(".. autoclass:: bvh_query_t", file=file)
- print(".. autoclass:: hash_grid_query_t", file=file)
- print(".. autoclass:: mesh_query_aabb_t", file=file)
- print(".. autoclass:: mesh_query_point_t", file=file)
- print(".. autoclass:: mesh_query_ray_t", file=file)
-
  # build dictionary of all functions by group
  groups = {}
 
@@ -5252,8 +5435,17 @@ def export_functions_rst(file): # pragma: no cover
  for o in f.overloads:
  groups[f.group].append(o)
 
- # Keep track of what function names have been written
- written_functions = {}
+ # Keep track of what function and query types have been written
+ written_functions = set()
+ written_query_types = set()
+
+ query_types = (
+ ("bvh_query", "BvhQuery"),
+ ("mesh_query_aabb", "MeshQueryAABB"),
+ ("mesh_query_point", "MeshQueryPoint"),
+ ("mesh_query_ray", "MeshQueryRay"),
+ ("hash_grid_query", "HashGridQuery"),
+ )
 
  for k, g in groups.items():
  print("\n", file=file)
@@ -5261,12 +5453,18 @@ def export_functions_rst(file): # pragma: no cover
  print("---------------", file=file)
 
  for f in g:
+ for f_prefix, query_type in query_types:
+ if f.key.startswith(f_prefix) and query_type not in written_query_types:
+ print(f".. autoclass:: {query_type}", file=file)
+ written_query_types.add(query_type)
+ break
+
  if f.key in written_functions:
  # Add :noindex: + :nocontentsentry: since Sphinx gets confused
  print_function(f, file=file, noentry=True)
  else:
  if print_function(f, file=file):
- written_functions[f.key] = []
+ written_functions.add(f.key)
 
  # footnotes
  print(".. rubric:: Footnotes", file=file)
@@ -5327,7 +5525,7 @@ def export_stubs(file): # pragma: no cover
  try:
  # todo: construct a default value for each of the functions args
  # so we can generate the return type for overloaded functions
- return_type = f.value_func(None, None, None)
+ return_type = f.value_func(None, None)
  if return_type:
  return_str = " -> " + type_str(return_type)
 
@@ -5373,21 +5571,25 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
  if not f.is_simple():
  continue
 
- args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
- params = ", ".join(f.input_types.keys())
-
- return_type = ""
-
  try:
  # todo: construct a default value for each of the functions args
  # so we can generate the return type for overloaded functions
- return_type = ctype_ret_str(f.value_func(None, None, None))
+ return_type = ctype_ret_str(f.value_func(None, None))
  except Exception:
  continue
 
  if return_type.startswith("Tuple"):
  continue
 
+ # Runtime arguments that are to be passed to the function, not its template signature.
+ if f.export_func is not None:
+ func_args = f.export_func(f.input_types)
+ else:
+ func_args = f.input_types
+
+ args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in func_args.items())
+ params = ", ".join(func_args.keys())
+
  if args == "":
  file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
  elif return_type == "None":