warp-lang 1.0.0b2-py3-none-win_amd64.whl → 1.0.0b6-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -7,8 +7,10 @@
 
 import ast
 import ctypes
+import gc
 import hashlib
 import inspect
+import io
 import os
 import platform
 import sys
@@ -68,6 +70,8 @@ class Function:
         native_func=None,
         defaults=None,
         custom_replay_func=None,
+        native_snippet=None,
+        adj_native_snippet=None,
         skip_forward_codegen=False,
         skip_reverse_codegen=False,
         custom_reverse_num_input_args=-1,
@@ -75,6 +79,7 @@ class Function:
         overloaded_annotations=None,
         code_transformers=[],
         skip_adding_overload=False,
+        require_original_output_arg=False,
     ):
         self.func = func  # points to Python function decorated with @wp.func, may be None for builtins
         self.key = key
@@ -90,7 +95,10 @@ class Function:
         self.defaults = defaults
         # Function instance for a custom implementation of the replay pass
         self.custom_replay_func = custom_replay_func
+        self.native_snippet = native_snippet
+        self.adj_native_snippet = adj_native_snippet
         self.custom_grad_func = None
+        self.require_original_output_arg = require_original_output_arg
 
         if initializer_list_func is None:
             self.initializer_list_func = lambda x, y: False
@@ -170,121 +178,24 @@ class Function:
         # from within a kernel (experimental).
 
         if self.is_builtin() and self.mangled_name:
-            # store last error during overload resolution
-            error = None
-
-            for f in self.overloads:
-                if f.generic:
+            # For each of this function's existing overloads, we attempt to pack
+            # the given arguments into the C types expected by the corresponding
+            # parameters, and we rinse and repeat until we get a match.
+            for overload in self.overloads:
+                if overload.generic:
                     continue
 
-                # try and find builtin in the warp.dll
-                if not hasattr(warp.context.runtime.core, f.mangled_name):
-                    raise RuntimeError(
-                        f"Couldn't find function {self.key} with mangled name {f.mangled_name} in the Warp native library"
-                    )
-
-                try:
-                    # try and pack args into what the function expects
-                    params = []
-                    for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
-                        a = args[i]
-
-                        # try to convert to a value type (vec3, mat33, etc)
-                        if issubclass(arg_type, ctypes.Array):
-                            # wrap the arg_type (which is an ctypes.Array) in a structure
-                            # to ensure parameter is passed to the .dll by value rather than reference
-                            class ValueArg(ctypes.Structure):
-                                _fields_ = [("value", arg_type)]
-
-                            x = ValueArg()
-
-                            # force conversion to ndarray first (handles tuple / list, Gf.Vec3 case)
-                            if isinstance(a, ctypes.Array) is False:
-                                # assume you want the float32 version of the function so it doesn't just
-                                # grab an override for a random data type:
-                                if arg_type._type_ != ctypes.c_float:
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' does not have c_float type."
-                                    )
-
-                                a = np.array(a)
-
-                                # flatten to 1D array
-                                v = a.flatten()
-                                if len(v) != arg_type._length_:
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' has length {len(v)}, but expected {arg_type._length_}. Could not convert parameter to {arg_type}."
-                                    )
-
-                                for i in range(arg_type._length_):
-                                    x.value[i] = v[i]
-
-                            else:
-                                # already a built-in type, check it matches
-                                if not warp.types.types_equal(type(a), arg_type):
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' has type '{type(a)}' but expected '{arg_type}'"
-                                    )
-
-                                x.value = a
-
-                            params.append(x)
-
-                        else:
-                            try:
-                                # try to pack as a scalar type
-                                params.append(arg_type._type_(a))
-                            except Exception:
-                                raise RuntimeError(
-                                    f"Error calling function {f.key}, unable to pack function parameter type {type(a)} for param {arg_name}, expected {arg_type}"
-                                )
-
-                    # returns the corresponding ctype for a scalar or vector warp type
-                    def type_ctype(dtype):
-                        if dtype == float:
-                            return ctypes.c_float
-                        elif dtype == int:
-                            return ctypes.c_int32
-                        elif issubclass(dtype, ctypes.Array):
-                            return dtype
-                        elif issubclass(dtype, ctypes.Structure):
-                            return dtype
-                        else:
-                            # scalar type
-                            return dtype._type_
-
-                    value_type = type_ctype(f.value_func(None, None, None))
-
-                    # construct return value (passed by address)
-                    ret = value_type()
-                    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
-
-                    params.append(ret_addr)
-
-                    c_func = getattr(warp.context.runtime.core, f.mangled_name)
-                    c_func(*params)
-
-                    if issubclass(value_type, ctypes.Array) or issubclass(value_type, ctypes.Structure):
-                        # return vector types as ctypes
-                        return ret
-                    else:
-                        # return scalar types as int/float
-                        return ret.value
-
-                except Exception as e:
-                    # couldn't pack values to match this overload
-                    # store error and move onto the next one
-                    error = e
-                    continue
+                success, return_value = call_builtin(overload, *args)
+                if success:
+                    return return_value
 
             # overload resolution or call failed
-            # raise the last exception encountered
-            if error:
-                raise error
-            else:
-                raise RuntimeError(f"Error calling function '{f.key}'.")
+            raise RuntimeError(
+                f"Couldn't find a function '{self.key}' compatible with "
+                f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
+            )
 
-        elif hasattr(self, "user_overloads") and len(self.user_overloads):
+        if hasattr(self, "user_overloads") and len(self.user_overloads):
             # user-defined function with overloads
 
             if len(kwargs):
@@ -293,28 +204,26 @@ class Function:
                 )
 
             # try and find a matching overload
-            for f in self.user_overloads.values():
-                if len(f.input_types) != len(args):
+            for overload in self.user_overloads.values():
+                if len(overload.input_types) != len(args):
                     continue
-                template_types = list(f.input_types.values())
-                arg_names = list(f.input_types.keys())
+                template_types = list(overload.input_types.values())
+                arg_names = list(overload.input_types.keys())
                 try:
                     # attempt to unify argument types with function template types
                     warp.types.infer_argument_types(args, template_types, arg_names)
-                    return f.func(*args)
+                    return overload.func(*args)
                 except Exception:
                     continue
 
             raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
 
-        else:
-            # user-defined function with no overloads
-
-            if self.func is None:
-                raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
+        # user-defined function with no overloads
+        if self.func is None:
+            raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
 
-            # this function has no overloads, call it like a plain Python function
-            return self.func(*args, **kwargs)
+        # this function has no overloads, call it like a plain Python function
+        return self.func(*args, **kwargs)
 
     def is_builtin(self):
         return self.func is None
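With the `else:` branch flattened out above, a user-defined `@wp.func` with no overloads is simply invoked as a plain Python function. A minimal sketch (the function itself is illustrative, not from this diff):

    import warp as wp

    @wp.func
    def sqr(x: float):
        return x * x

    # no overloads registered, so Function.__call__ falls through to self.func
    print(sqr(3.0))  # 9.0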
@@ -427,10 +336,188 @@ class Function:
         return None
 
     def __repr__(self):
-        inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in self.input_types.items()])
+        inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
         return f"<Function {self.key}({inputs_str})>"
 
 
+def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
+    uses_non_warp_array_type = False
+
+    # Retrieve the built-in function from Warp's dll.
+    c_func = getattr(warp.context.runtime.core, func.mangled_name)
+
+    # Try gathering the parameters that the function expects and pack them
+    # into their corresponding C types.
+    c_params = []
+    for i, (_, arg_type) in enumerate(func.input_types.items()):
+        param = params[i]
+
+        try:
+            iter(param)
+        except TypeError:
+            is_array = False
+        else:
+            is_array = True
+
+        if is_array:
+            if not issubclass(arg_type, ctypes.Array):
+                return (False, None)
+
+            # The argument expects a built-in Warp type like a vector or a matrix.
+
+            c_param = None
+
+            if isinstance(param, ctypes.Array):
+                # The given parameter is also a built-in Warp type, so we only need
+                # to make sure that it matches with the argument.
+                if not warp.types.types_equal(type(param), arg_type):
+                    return (False, None)
+
+                if isinstance(param, arg_type):
+                    c_param = param
+                else:
+                    # Cast the value to its argument type to make sure that it
+                    # can be assigned to the field of the `Param` struct.
+                    # This could error otherwise when, for example, the field type
+                    # is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
+                    # even though both types are semantically identical.
+                    c_param = arg_type(param)
+            else:
+                # Flatten the parameter values into a flat 1-D array.
+                arr = []
+                ndim = 1
+                stack = [(0, param)]
+                while stack:
+                    depth, elem = stack.pop(0)
+                    try:
+                        # If `elem` is a sequence, then it should be possible
+                        # to add its elements to the stack for later processing.
+                        stack.extend((depth + 1, x) for x in elem)
+                    except TypeError:
+                        # Since `elem` doesn't seem to be a sequence,
+                        # we must have a leaf value that we need to add to our
+                        # resulting array.
+                        arr.append(elem)
+                        ndim = max(depth, ndim)
+
+                assert ndim > 0
+
+                # Ensure that if the given parameter value is, say, a 2-D array,
+                # then we try to resolve it against a matrix argument rather than
+                # a vector.
+                if ndim > len(arg_type._shape_):
+                    return (False, None)
+
+                elem_count = len(arr)
+                if elem_count != arg_type._length_:
+                    return (False, None)
+
+                # Retrieve the element type of the sequence while ensuring
+                # that it's homogeneous.
+                elem_type = type(arr[0])
+                for i in range(1, elem_count):
+                    if type(arr[i]) is not elem_type:
+                        raise ValueError("All array elements must share the same type.")
+
+                expected_elem_type = arg_type._wp_scalar_type_
+                if not (
+                    elem_type is expected_elem_type
+                    or (elem_type is float and expected_elem_type is warp.types.float32)
+                    or (elem_type is int and expected_elem_type is warp.types.int32)
+                    or (
+                        issubclass(elem_type, np.number)
+                        and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
+                    )
+                ):
+                    # The parameter value has a type not matching the type defined
+                    # for the corresponding argument.
+                    return (False, None)
+
+                if elem_type in warp.types.int_types:
+                    # Pass the value through the expected integer type
+                    # in order to evaluate any integer wrapping.
+                    # For example `uint8(-1)` should result in the value `-255`.
+                    arr = tuple(elem_type._type_(x.value).value for x in arr)
+                elif elem_type in warp.types.float_types:
+                    # Extract the floating-point values.
+                    arr = tuple(x.value for x in arr)
+
+                c_param = arg_type()
+                if warp.types.type_is_matrix(arg_type):
+                    rows, cols = arg_type._shape_
+                    for i in range(rows):
+                        idx_start = i * cols
+                        idx_end = idx_start + cols
+                        c_param[i] = arr[idx_start:idx_end]
+                else:
+                    c_param[:] = arr
+
+                uses_non_warp_array_type = True
+
+            c_params.append(ctypes.byref(c_param))
+        else:
+            if issubclass(arg_type, ctypes.Array):
+                return (False, None)
+
+            if not (
+                isinstance(param, arg_type)
+                or (type(param) is float and arg_type is warp.types.float32)
+                or (type(param) is int and arg_type is warp.types.int32)
+                or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
+            ):
+                return (False, None)
+
+            if type(param) in warp.types.scalar_types:
+                param = param.value
+
+            # try to pack as a scalar type
+            if arg_type == warp.types.float16:
+                c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
+            else:
+                c_params.append(arg_type._type_(param))
+
+    # returns the corresponding ctype for a scalar or vector warp type
+    value_type = func.value_func(None, None, None)
+    if value_type == float:
+        value_ctype = ctypes.c_float
+    elif value_type == int:
+        value_ctype = ctypes.c_int32
+    elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
+        value_ctype = value_type
+    else:
+        # scalar type
+        value_ctype = value_type._type_
+
+    # construct return value (passed by address)
+    ret = value_ctype()
+    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
+    c_params.append(ret_addr)
+
+    # Call the built-in function from Warp's dll.
+    c_func(*c_params)
+
+    # TODO: uncomment when we have a way to print warning messages only once.
+    # if uses_non_warp_array_type:
+    #     warp.utils.warn(
+    #         "Support for built-in functions called with non-Warp array types, "
+    #         "such as lists, tuples, NumPy arrays, and others, will be dropped "
+    #         "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
+    #         "`wp.quat`, or `wp.transform`.",
+    #         DeprecationWarning,
+    #         stacklevel=3
+    #     )
+
+    if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
+        # return vector types as ctypes
+        return (True, ret)
+
+    if value_type == warp.types.float16:
+        return (True, warp.types.half_bits_to_float(ret.value))
+
+    # return scalar types as int/float
+    return (True, ret.value)
+
+
 class KernelHooks:
     def __init__(self, forward, backward):
         self.forward = forward
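The new `call_builtin()` helper is what makes builtins callable from the Python interpreter; a short sketch of both accepted forms (values are illustrative):

    import warp as wp

    wp.init()

    # exact overload match: the vec3 is packed by value, a float is returned
    print(wp.length(wp.vec3(3.0, 4.0, 0.0)))  # 5.0

    # plain sequences are flattened and matched against the vec3 overload;
    # per the TODO above, this path is slated for deprecation
    print(wp.length((3.0, 4.0, 0.0)))  # 5.0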
@@ -439,10 +526,20 @@ class KernelHooks:
 
 # caches source and compiled entry points for a kernel (will be populated after module loads)
 class Kernel:
-    def __init__(self, func, key, module, options=None, code_transformers=[]):
+    def __init__(self, func, key=None, module=None, options=None, code_transformers=[]):
         self.func = func
-        self.module = module
-        self.key = key
+
+        if module is None:
+            self.module = get_module(func.__module__)
+        else:
+            self.module = module
+
+        if key is None:
+            unique_key = self.module.generate_unique_kernel_key(func.__name__)
+            self.key = unique_key
+        else:
+            self.key = key
+
         self.options = {} if options is None else options
 
         self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
@@ -463,8 +560,8 @@ class Kernel:
         # argument indices by name
         self.arg_indices = dict((a.label, i) for i, a in enumerate(self.adj.args))
 
-        if module:
-            module.register_kernel(self)
+        if self.module:
+            self.module.register_kernel(self)
 
     def infer_argument_types(self, args):
         template_types = list(self.adj.arg_types.values())
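Since `key` and `module` are now optional, a `Kernel` can be built straight from an annotated Python function; a sketch, assuming `Kernel` is reached via `warp.context` (the kernel itself is illustrative):

    import warp as wp

    wp.init()

    def scale(x: wp.array(dtype=float), s: float):
        i = wp.tid()
        x[i] = x[i] * s

    # module inferred from scale.__module__, key auto-generated via
    # Module.generate_unique_kernel_key()
    kernel = wp.context.Kernel(func=scale)

    x = wp.array([1.0, 2.0, 3.0], dtype=float)
    wp.launch(kernel, dim=x.size, inputs=[x, 2.0])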
@@ -541,7 +638,7 @@ def func(f):
     name = warp.codegen.make_full_qualified_name(f)
 
     m = get_module(f.__module__)
-    func = Function(
+    Function(
         func=f, key=name, namespace="", module=m, value_func=None
     )  # value_type not known yet, will be inferred during Adjoint.build()
 
@@ -549,6 +646,24 @@ def func(f):
     return m.functions[name]
 
 
+def func_native(snippet, adj_snippet=None):
+    """
+    Decorator to register native code snippet, @func_native
+    """
+
+    def snippet_func(f):
+        name = warp.codegen.make_full_qualified_name(f)
+
+        m = get_module(f.__module__)
+        func = Function(
+            func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet
+        )  # cuda snippets do not have a return value_type
+
+        return m.functions[name]
+
+    return snippet_func
+
+
 def func_grad(forward_fn):
     """
     Decorator to register a custom gradient function for a given forward function.
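The new `@wp.func_native` decorator above registers a raw C++/CUDA snippet as a callable function; a minimal sketch modeled on the new `warp/tests/test_snippet.py` added in this release (the snippet body and kernel are illustrative):

    import warp as wp

    wp.init()

    snippet = """
        out[tid] = a * x[tid] + y[tid];
    """

    @wp.func_native(snippet)
    def saxpy(
        a: wp.float32,
        x: wp.array(dtype=wp.float32),
        y: wp.array(dtype=wp.float32),
        out: wp.array(dtype=wp.float32),
        tid: int,
    ):
        ...

    @wp.kernel
    def saxpy_kernel(
        a: wp.float32,
        x: wp.array(dtype=wp.float32),
        y: wp.array(dtype=wp.float32),
        out: wp.array(dtype=wp.float32),
    ):
        tid = wp.tid()
        saxpy(a, x, y, out, tid)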
@@ -819,6 +934,7 @@ def add_builtin(
     missing_grad=False,
     native_func=None,
     defaults=None,
+    require_original_output_arg=False,
 ):
     # wrap simple single-type functions with a value_func()
     if value_func is None:
@@ -912,7 +1028,7 @@
             # on the generated argument list and skip generation if it fails.
             # This also gives us the return type, which we keep for later:
             try:
-                return_type = value_func([warp.codegen.Var("", t) for t in argtypes], {}, [])
+                return_type = value_func(argtypes, {}, [])
             except Exception:
                 continue
 
@@ -943,6 +1059,7 @@
                 hidden=True,
                 skip_replay=skip_replay,
                 missing_grad=missing_grad,
+                require_original_output_arg=require_original_output_arg,
             )
 
     func = Function(
@@ -963,6 +1080,7 @@
         generic=generic,
         native_func=native_func,
        defaults=defaults,
+        require_original_output_arg=require_original_output_arg,
     )
 
     if key in builtin_functions:
@@ -972,7 +1090,7 @@
 
     # export means the function will be added to the `warp` module namespace
     # so that users can call it directly from the Python interpreter
-    if export is True:
+    if export:
         if hasattr(warp, key):
             # check that we haven't already created something at this location
             # if it's just an overload stub for auto-complete then overwrite it
@@ -1057,8 +1175,7 @@ class ModuleBuilder:
         while stack:
             s = stack.pop()
 
-            if s not in structs:
-                structs.append(s)
+            structs.append(s)
 
             for var in s.vars.values():
                 if isinstance(var.type, warp.codegen.Struct):
@@ -1090,7 +1207,7 @@
         if not func.value_func:
 
            def wrap(adj):
-                def value_type(args, kwds, templates):
+                def value_type(arg_types, kwds, templates):
                     if adj.return_var is None or len(adj.return_var) == 0:
                         return None
                     if len(adj.return_var) == 1:
@@ -1114,9 +1231,14 @@
 
         # code-gen all imported functions
         for func in self.functions.keys():
-            source += warp.codegen.codegen_func(
-                func.adj, c_func_name=func.native_func, device=device, options=self.options
-            )
+            if func.native_snippet is None:
+                source += warp.codegen.codegen_func(
+                    func.adj, c_func_name=func.native_func, device=device, options=self.options
+                )
+            else:
+                source += warp.codegen.codegen_snippet(
+                    func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
+                )
 
         for kernel in self.module.kernels.values():
             # each kernel gets an entry point in the module
@@ -1196,6 +1318,10 @@ class Module:
 
         self.content_hash = None
 
+        # number of times module auto-generates kernel key for user
+        # used to ensure unique kernel keys
+        self.count = 0
+
     def register_struct(self, struct):
         self.structs[struct.key] = struct
 
@@ -1238,6 +1364,11 @@
         # for a reload of module on next launch
         self.unload()
 
+    def generate_unique_kernel_key(self, key):
+        unique_key = f"{key}_{self.count}"
+        self.count += 1
+        return unique_key
+
     # collect all referenced functions / structs
     # given the AST of a function or kernel
     def find_references(self, adj):
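The per-module counter above keeps auto-generated kernel keys unique; a quick sketch of its behavior:

    import warp as wp

    module = wp.context.get_module(__name__)
    print(module.generate_unique_kernel_key("scale"))  # scale_0
    print(module.generate_unique_kernel_key("scale"))  # scale_1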
@@ -1251,7 +1382,7 @@
             if isinstance(node, ast.Call):
                 try:
                     # try to resolve the function
-                    func, _ = adj.resolve_path(node.func)
+                    func, _ = adj.resolve_static_expression(node.func, eval_types=False)
 
                     # if this is a user-defined function, add a module reference
                     if isinstance(func, warp.context.Function) and func.module is not None:
@@ -1304,9 +1435,24 @@
             s = func.adj.source
             ch.update(bytes(s, "utf-8"))
 
+            if func.custom_grad_func:
+                s = func.custom_grad_func.adj.source
+                ch.update(bytes(s, "utf-8"))
+            if func.custom_replay_func:
+                s = func.custom_replay_func.adj.source
+
+            # cache func arg types
+            for arg, arg_type in func.adj.arg_types.items():
+                s = f"{arg}: {get_type_name(arg_type)}"
+                ch.update(bytes(s, "utf-8"))
+
        # kernel source
        for kernel in module.kernels.values():
            ch.update(bytes(kernel.adj.source, "utf-8"))
+            # cache kernel arg types
+            for arg, arg_type in kernel.adj.arg_types.items():
+                s = f"{arg}: {get_type_name(arg_type)}"
+                ch.update(bytes(s, "utf-8"))
            # for generic kernels the Python source is always the same,
            # but we hash the type signatures of all the overloads
            if kernel.is_generic:
@@ -1605,13 +1751,13 @@ class ContextGuard:
     def __enter__(self):
         if self.device.is_cuda:
             runtime.core.cuda_context_push_current(self.device.context)
-        elif is_cuda_available():
+        elif is_cuda_driver_initialized():
             self.saved_context = runtime.core.cuda_context_get_current()
 
     def __exit__(self, exc_type, exc_value, traceback):
         if self.device.is_cuda:
             runtime.core.cuda_context_pop_current()
-        elif is_cuda_available():
+        elif is_cuda_driver_initialized():
             runtime.core.cuda_context_set_current(self.saved_context)
 
 
@@ -1896,7 +2042,7 @@ class Runtime:
 
         self.core = self.load_dll(warp_lib)
 
-        if llvm_lib and os.path.exists(llvm_lib):
+        if os.path.exists(llvm_lib):
             self.llvm = self.load_dll(llvm_lib)
             # setup c-types for warp-clang.dll
             self.llvm.lookup.restype = ctypes.c_uint64
@@ -2262,6 +2408,8 @@ class Runtime:
         self.core.cuda_driver_version.restype = ctypes.c_int
         self.core.cuda_toolkit_version.argtypes = None
         self.core.cuda_toolkit_version.restype = ctypes.c_int
+        self.core.cuda_driver_is_initialized.argtypes = None
+        self.core.cuda_driver_is_initialized.restype = ctypes.c_bool
 
         self.core.nvrtc_supported_arch_count.argtypes = None
         self.core.nvrtc_supported_arch_count.restype = ctypes.c_int
@@ -2364,6 +2512,7 @@ class Runtime:
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_size_t,
+            ctypes.c_int,
             ctypes.POINTER(ctypes.c_void_p),
         ]
         self.core.cuda_launch_kernel.restype = ctypes.c_size_t
@@ -2484,8 +2633,15 @@ class Runtime:
                 dll = ctypes.CDLL(dll_path, winmode=0)
             else:
                 dll = ctypes.CDLL(dll_path)
-        except OSError:
-            raise RuntimeError(f"Failed to load the shared library '{dll_path}'")
+        except OSError as e:
+            if "GLIBCXX" in str(e):
+                raise RuntimeError(
+                    f"Failed to load the shared library '{dll_path}'.\n"
+                    "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n"
+                    "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details."
+                ) from e
+            else:
+                raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e
         return dll
 
     def get_device(self, ident: Devicelike = None) -> Device:
@@ -2614,6 +2770,21 @@ def is_device_available(device):
     return device in get_devices()
 
 
+def is_cuda_driver_initialized() -> bool:
+    """Returns ``True`` if the CUDA driver is initialized.
+
+    This is a stricter test than ``is_cuda_available()`` since a CUDA driver
+    call to ``cuCtxGetCurrent`` is made, and the result is compared to
+    `CUDA_SUCCESS`. Note that `CUDA_SUCCESS` is returned by ``cuCtxGetCurrent``
+    even if there is no context bound to the calling CPU thread.
+
+    This can be helpful in cases in which ``cuInit()`` was called before a fork.
+    """
+    assert_initialized()
+
+    return runtime.core.cuda_driver_is_initialized()
+
+
 def get_devices() -> List[Device]:
     """Returns a list of devices supported in this environment."""
 
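A hypothetical usage sketch for the stricter check, e.g. guarding CUDA context work after a fork:

    import warp as wp

    wp.init()

    # unlike is_cuda_available(), this issues a real driver call, so it can
    # detect that cuInit() state from before a fork is not usable here
    if wp.context.is_cuda_driver_initialized():
        wp.synchronize()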
@@ -3090,9 +3261,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
             # - in forward passes, array types have to match
             # - in backward passes, indexed array gradients are regular arrays
             if adjoint:
-                array_matches = type(value) == warp.array
+                array_matches = isinstance(value, warp.array)
             else:
-                array_matches = type(value) == type(arg_type)
+                array_matches = type(value) is type(arg_type)
 
             if not array_matches:
                 adj = "adjoint " if adjoint else ""
@@ -3172,7 +3343,7 @@
 # represents all data required for a kernel launch
 # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
 class Launch:
-    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None):
+    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
         # if not specified look up hooks
         if not hooks:
             module = kernel.module
@@ -3209,6 +3380,7 @@ class Launch:
         self.params_addr = params_addr
         self.device = device
         self.bounds = bounds
+        self.max_blocks = max_blocks
 
     def set_dim(self, dim):
         self.bounds = warp.types.launch_bounds_t(dim)
@@ -3274,7 +3446,9 @@ class Launch:
         if self.device.is_cpu:
             self.hooks.forward(*self.params)
         else:
-            runtime.core.cuda_launch_kernel(self.device.context, self.hooks.forward, self.bounds.size, self.params_addr)
+            runtime.core.cuda_launch_kernel(
+                self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
+            )
 
 
 def launch(
@@ -3289,6 +3463,7 @@ def launch(
     adjoint=False,
     record_tape=True,
    record_cmd=False,
+    max_blocks=0,
 ):
     """Launch a Warp kernel on the target device
 
@@ -3306,6 +3481,8 @@ def launch(
         adjoint: Whether to run forward or backward pass (typically use False)
         record_tape: When true the launch will be recorded the global wp.Tape() object when present
         record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
+        max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
+            If negative or zero, the maximum hardware value will be used.
     """
 
     assert_initialized()
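A sketch of the new `max_blocks` argument in use (the kernel and sizes are illustrative):

    import warp as wp

    wp.init()

    @wp.kernel
    def fill(a: wp.array(dtype=float), value: float):
        i = wp.tid()
        a[i] = value

    a = wp.zeros(1_000_000, dtype=float, device="cuda:0")

    # cap the launch at 256 thread blocks; the default of 0 lets the runtime
    # use the maximum hardware value
    wp.launch(fill, dim=a.size, inputs=[a, 2.0], device="cuda:0", max_blocks=256)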
@@ -3317,7 +3494,7 @@ def launch(
     device = runtime.get_device(device)
 
     # check function is a Kernel
-    if isinstance(kernel, Kernel) is False:
+    if not isinstance(kernel, Kernel):
         raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
 
     # debugging aid
@@ -3399,7 +3576,9 @@ def launch(
                         f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                     )
 
-                runtime.core.cuda_launch_kernel(device.context, hooks.backward, bounds.size, kernel_params)
+                runtime.core.cuda_launch_kernel(
+                    device.context, hooks.backward, bounds.size, max_blocks, kernel_params
+                )
 
             else:
                 if hooks.forward is None:
@@ -3420,7 +3599,9 @@ def launch(
 
                 else:
                     # launch
-                    runtime.core.cuda_launch_kernel(device.context, hooks.forward, bounds.size, kernel_params)
+                    runtime.core.cuda_launch_kernel(
+                        device.context, hooks.forward, bounds.size, max_blocks, kernel_params
+                    )
 
     try:
         runtime.verify_cuda_device(device)
@@ -3430,7 +3611,7 @@ def launch(
 
     # record on tape if one is active
     if runtime.tape and record_tape:
-        runtime.tape.record_launch(kernel, dim, inputs, outputs, device)
+        runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device)
 
 
 def synchronize():
@@ -3440,7 +3621,7 @@ def synchronize():
     or memory copies have completed.
     """
 
-    if is_cuda_available():
+    if is_cuda_driver_initialized():
        # save the original context to avoid side effects
        saved_context = runtime.core.cuda_context_get_current()
 
@@ -3490,7 +3671,7 @@ def synchronize_stream(stream_or_device=None):
     runtime.core.cuda_stream_synchronize(stream.device.context, stream.cuda_stream)
 
 
-def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
+def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
     """Force user-defined kernels to be compiled and loaded
 
     Args:
@@ -3498,12 +3679,14 @@ def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
         modules: List of modules to load. If None, load all imported modules.
     """
 
-    if is_cuda_available():
+    if is_cuda_driver_initialized():
         # save original context to avoid side effects
         saved_context = runtime.core.cuda_context_get_current()
 
     if device is None:
         devices = get_devices()
+    elif isinstance(device, list):
+        devices = [get_device(device_item) for device_item in device]
     else:
         devices = [get_device(device)]
 
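`force_load()` now also accepts a list of devices; a sketch:

    import warp as wp

    wp.init()

    # compile and load all imported modules for several devices at once
    wp.force_load(device=["cpu", "cuda:0"])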
@@ -3595,7 +3778,7 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
     return get_module(m.__name__).options
 
 
-def capture_begin(device: Devicelike = None, stream=None, force_module_load=True):
+def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
     """Begin capture of a CUDA graph
 
     Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -3609,7 +3792,10 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
 
     """
 
-    if warp.config.verify_cuda is True:
+    if force_module_load is None:
+        force_module_load = warp.config.graph_capture_module_load_default
+
+    if warp.config.verify_cuda:
         raise RuntimeError("Cannot use CUDA error verification during graph capture")
 
     if stream is not None:
@@ -3624,6 +3810,9 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
 
     device.is_capturing = True
 
+    # disable garbage collection to avoid older allocations getting collected during graph capture
+    gc.disable()
+
     with warp.ScopedStream(stream):
         runtime.core.cuda_graph_begin_capture(device.context)
 
@@ -3647,6 +3836,9 @@ def capture_end(device: Devicelike = None, stream=None) -> Graph:
 
     device.is_capturing = False
 
+    # re-enable GC
+    gc.enable()
+
     if graph is None:
         raise RuntimeError(
             "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
@@ -3841,7 +4033,7 @@ def type_str(t):
         return t.__name__
 
 
-def print_function(f, file, noentry=False):
+def print_function(f, file, noentry=False):  # pragma: no cover
     """Writes a function definition to a file for use in reST documentation
 
     Args:
@@ -3886,7 +4078,7 @@ def print_function(f, file, noentry=False):
     return True
 
 
-def print_builtins(file):
+def export_functions_rst(file):  # pragma: no cover
     header = (
         "..\n"
         " Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
@@ -3906,6 +4098,8 @@
 
     for t in warp.types.scalar_types:
         print(f".. class:: {t.__name__}", file=file)
+    # Manually add wp.bool since it's inconvenient to add to wp.types.scalar_types:
+    print(f".. class:: {warp.types.bool.__name__}", file=file)
 
     print("\n\nVector Types", file=file)
     print("------------", file=file)
@@ -3925,6 +4119,14 @@
     print(".. class:: Transformation", file=file)
     print(".. class:: Array", file=file)
 
+    print("\nQuery Types", file=file)
+    print("-------------", file=file)
+    print(".. autoclass:: bvh_query_t", file=file)
+    print(".. autoclass:: hash_grid_query_t", file=file)
+    print(".. autoclass:: mesh_query_aabb_t", file=file)
+    print(".. autoclass:: mesh_query_point_t", file=file)
+    print(".. autoclass:: mesh_query_ray_t", file=file)
+
     # build dictionary of all functions by group
     groups = {}
 
@@ -3958,7 +4160,7 @@
     print(".. [1] Note: function gradients not implemented for backpropagation.", file=file)
 
 
-def export_stubs(file):
+def export_stubs(file):  # pragma: no cover
     """Generates stub file for auto-complete of builtin functions"""
 
     import textwrap
@@ -3990,6 +4192,8 @@ def export_stubs(file):
     print("Quaternion = Generic[Float]", file=file)
     print("Transformation = Generic[Float]", file=file)
     print("Array = Generic[DType]", file=file)
+    print("FabricArray = Generic[DType]", file=file)
+    print("IndexedFabricArray = Generic[DType]", file=file)
 
     # prepend __init__.py
     with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
@@ -4006,7 +4210,7 @@ def export_stubs(file):
 
         return_str = ""
 
-        if f.export is False or f.hidden is True:  # or f.generic:
+        if not f.export or f.hidden:  # or f.generic:
             continue
 
         try:
@@ -4027,8 +4231,18 @@
         print(" ...\n\n", file=file)
 
 
-def export_builtins(file):
-    def ctype_str(t):
+def export_builtins(file: io.TextIOBase):  # pragma: no cover
+    def ctype_arg_str(t):
+        if isinstance(t, int):
+            return "int"
+        elif isinstance(t, float):
+            return "float"
+        elif t in warp.types.vector_types:
+            return f"{t.__name__}&"
+        else:
+            return t.__name__
+
+    def ctype_ret_str(t):
         if isinstance(t, int):
             return "int"
         elif isinstance(t, float):
@@ -4036,9 +4250,12 @@ def export_builtins(file):
         else:
             return t.__name__
 
+    file.write("namespace wp {\n\n")
+    file.write('extern "C" {\n\n')
+
     for k, g in builtin_functions.items():
         for f in g.overloads:
-            if f.export is False or f.generic:
+            if not f.export or f.generic:
                 continue
 
             simple = True
@@ -4052,7 +4269,7 @@ def export_builtins(file):
             if not simple or f.variadic:
                 continue
 
-            args = ", ".join(f"{ctype_str(v)} {k}" for k, v in f.input_types.items())
+            args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
             params = ", ".join(f.input_types.keys())
 
             return_type = ""
@@ -4060,7 +4277,7 @@ def export_builtins(file):
 
             try:
                 # todo: construct a default value for each of the functions args
                 # so we can generate the return type for overloaded functions
-                return_type = ctype_str(f.value_func(None, None, None))
+                return_type = ctype_ret_str(f.value_func(None, None, None))
             except Exception:
                 continue
@@ -4068,17 +4285,17 @@ def export_builtins(file):
                 continue
 
             if args == "":
-                print(
-                    f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}", file=file
-                )
+                file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
             elif return_type == "None":
-                print(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}", file=file)
+                file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
             else:
-                print(
-                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}",
-                    file=file,
+                file.write(
+                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
                 )
 
+    file.write('\n} // extern "C"\n\n')
+    file.write("} // namespace wp\n")
+
 
 # initialize global runtime
 runtime = None
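`export_builtins()` now takes a text stream and wraps the generated wrappers in `namespace wp { extern "C" { ... } }` (presumably how `warp/native/exports.h` in this release was regenerated); a sketch of driving it by hand:

    import io

    import warp as wp

    wp.init()

    buf = io.StringIO()
    wp.context.export_builtins(buf)

    # the generated source now begins with the namespace prologue
    print(buf.getvalue().splitlines()[0])  # namespace wp {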