warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -5,36 +5,27 @@
  # distribution of this software and related documentation without an express
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
- import math
- import os
- import sys
- import hashlib
+ import ast
  import ctypes
+ import gc
+ import hashlib
+ import inspect
+ import io
+ import os
  import platform
- import ast
+ import sys
  import types
- import inspect
-
- from typing import Tuple
- from typing import List
- from typing import Dict
- from typing import Any
- from typing import Callable
- from typing import Union
- from typing import Mapping
- from typing import Optional
-
+ from copy import copy as shallowcopy
  from types import ModuleType
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
- from copy import copy as shallowcopy
+ import numpy as np
 
  import warp
- import warp.codegen
  import warp.build
+ import warp.codegen
  import warp.config
 
- import numpy as np
-
  # represents either a built-in or user-defined function
 
 
@@ -45,6 +36,18 @@ def create_value_func(type):
  return value_func
 
 
+ def get_function_args(func):
+ """Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type."""
+ import inspect
+
+ argspec = inspect.getfullargspec(func)
+
+ # use source-level argument annotations
+ if len(argspec.annotations) < len(argspec.args):
+ raise RuntimeError(f"Incomplete argument annotations on function {func.__qualname__}")
+ return argspec.annotations
+
+
  class Function:
  def __init__(
  self,
@@ -66,8 +69,17 @@ class Function:
  generic=False,
  native_func=None,
  defaults=None,
+ custom_replay_func=None,
+ native_snippet=None,
+ adj_native_snippet=None,
+ skip_forward_codegen=False,
+ skip_reverse_codegen=False,
+ custom_reverse_num_input_args=-1,
+ custom_reverse_mode=False,
  overloaded_annotations=None,
  code_transformers=[],
+ skip_adding_overload=False,
+ require_original_output_arg=False,
  ):
  self.func = func # points to Python function decorated with @wp.func, may be None for builtins
  self.key = key
@@ -81,6 +93,12 @@ class Function:
  self.module = module
  self.variadic = variadic # function can take arbitrary number of inputs, e.g.: printf()
  self.defaults = defaults
+ # Function instance for a custom implementation of the replay pass
+ self.custom_replay_func = custom_replay_func
+ self.native_snippet = native_snippet
+ self.adj_native_snippet = adj_native_snippet
+ self.custom_grad_func = None
+ self.require_original_output_arg = require_original_output_arg
 
  if initializer_list_func is None:
  self.initializer_list_func = lambda x, y: False
@@ -110,7 +128,14 @@ class Function:
 
  # user defined (Python) function
  self.adj = warp.codegen.Adjoint(
- func, overload_annotations=overloaded_annotations, transformers=code_transformers
+ func,
+ is_user_function=True,
+ skip_forward_codegen=skip_forward_codegen,
+ skip_reverse_codegen=skip_reverse_codegen,
+ custom_reverse_num_input_args=custom_reverse_num_input_args,
+ custom_reverse_mode=custom_reverse_mode,
+ overload_annotations=overloaded_annotations,
+ transformers=code_transformers,
  )
 
  # record input types
@@ -139,11 +164,12 @@ class Function:
  else:
  self.mangled_name = None
 
- self.add_overload(self)
+ if not skip_adding_overload:
+ self.add_overload(self)
 
  # add to current module
  if module:
- module.register_function(self)
+ module.register_function(self, skip_adding_overload)
 
  def __call__(self, *args, **kwargs):
  # handles calling a builtin (native) function
@@ -152,121 +178,24 @@ class Function:
  # from within a kernel (experimental).
 
  if self.is_builtin() and self.mangled_name:
- # store last error during overload resolution
- error = None
-
- for f in self.overloads:
- if f.generic:
+ # For each of this function's existing overloads, we attempt to pack
+ # the given arguments into the C types expected by the corresponding
+ # parameters, and we rinse and repeat until we get a match.
+ for overload in self.overloads:
+ if overload.generic:
  continue
 
- # try and find builtin in the warp.dll
- if not hasattr(warp.context.runtime.core, f.mangled_name):
- raise RuntimeError(
- f"Couldn't find function {self.key} with mangled name {f.mangled_name} in the Warp native library"
- )
-
- try:
- # try and pack args into what the function expects
- params = []
- for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
- a = args[i]
-
- # try to convert to a value type (vec3, mat33, etc)
- if issubclass(arg_type, ctypes.Array):
- # wrap the arg_type (which is an ctypes.Array) in a structure
- # to ensure parameter is passed to the .dll by value rather than reference
- class ValueArg(ctypes.Structure):
- _fields_ = [("value", arg_type)]
-
- x = ValueArg()
-
- # force conversion to ndarray first (handles tuple / list, Gf.Vec3 case)
- if isinstance(a, ctypes.Array) == False:
- # assume you want the float32 version of the function so it doesn't just
- # grab an override for a random data type:
- if arg_type._type_ != ctypes.c_float:
- raise RuntimeError(
- f"Error calling function '{f.key}', parameter for argument '{arg_name}' does not have c_float type."
- )
-
- a = np.array(a)
-
- # flatten to 1D array
- v = a.flatten()
- if len(v) != arg_type._length_:
- raise RuntimeError(
- f"Error calling function '{f.key}', parameter for argument '{arg_name}' has length {len(v)}, but expected {arg_type._length_}. Could not convert parameter to {arg_type}."
- )
-
- for i in range(arg_type._length_):
- x.value[i] = v[i]
-
- else:
- # already a built-in type, check it matches
- if not warp.types.types_equal(type(a), arg_type):
- raise RuntimeError(
- f"Error calling function '{f.key}', parameter for argument '{arg_name}' has type '{type(a)}' but expected '{arg_type}'"
- )
-
- x.value = a
-
- params.append(x)
-
- else:
- try:
- # try to pack as a scalar type
- params.append(arg_type._type_(a))
- except:
- raise RuntimeError(
- f"Error calling function {f.key}, unable to pack function parameter type {type(a)} for param {arg_name}, expected {arg_type}"
- )
-
- # returns the corresponding ctype for a scalar or vector warp type
- def type_ctype(dtype):
- if dtype == float:
- return ctypes.c_float
- elif dtype == int:
- return ctypes.c_int32
- elif issubclass(dtype, ctypes.Array):
- return dtype
- elif issubclass(dtype, ctypes.Structure):
- return dtype
- else:
- # scalar type
- return dtype._type_
-
- value_type = type_ctype(f.value_func(None, None, None))
-
- # construct return value (passed by address)
- ret = value_type()
- ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
-
- params.append(ret_addr)
-
- c_func = getattr(warp.context.runtime.core, f.mangled_name)
- c_func(*params)
-
- if issubclass(value_type, ctypes.Array) or issubclass(value_type, ctypes.Structure):
- # return vector types as ctypes
- return ret
- else:
- # return scalar types as int/float
- return ret.value
-
- except Exception as e:
- # couldn't pack values to match this overload
- # store error and move onto the next one
- error = e
- continue
+ success, return_value = call_builtin(overload, *args)
+ if success:
+ return return_value
 
  # overload resolution or call failed
- # raise the last exception encountered
- if error:
- raise error
- else:
- raise RuntimeError(f"Error calling function '{f.key}'.")
+ raise RuntimeError(
+ f"Couldn't find a function '{self.key}' compatible with "
+ f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
+ )
 
- elif hasattr(self, "user_overloads") and len(self.user_overloads):
+ if hasattr(self, "user_overloads") and len(self.user_overloads):
  # user-defined function with overloads
 
  if len(kwargs):
@@ -275,28 +204,26 @@ class Function:
  )
 
  # try and find a matching overload
- for f in self.user_overloads.values():
- if len(f.input_types) != len(args):
+ for overload in self.user_overloads.values():
+ if len(overload.input_types) != len(args):
  continue
- template_types = list(f.input_types.values())
- arg_names = list(f.input_types.keys())
+ template_types = list(overload.input_types.values())
+ arg_names = list(overload.input_types.keys())
  try:
  # attempt to unify argument types with function template types
  warp.types.infer_argument_types(args, template_types, arg_names)
- return f.func(*args)
+ return overload.func(*args)
  except Exception:
  continue
 
  raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
 
- else:
- # user-defined function with no overloads
-
- if self.func is None:
- raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
+ # user-defined function with no overloads
+ if self.func is None:
+ raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
 
- # this function has no overloads, call it like a plain Python function
- return self.func(*args, **kwargs)
+ # this function has no overloads, call it like a plain Python function
+ return self.func(*args, **kwargs)
 
  def is_builtin(self):
  return self.func is None
@@ -316,7 +243,7 @@ class Function:
  # todo: construct a default value for each of the functions args
  # so we can generate the return type for overloaded functions
  return_type = type_str(self.value_func(None, None, None))
- except:
+ except Exception:
  return False
 
  if return_type.startswith("Tuple"):
@@ -409,10 +336,187 @@ class Function:
  return None
 
  def __repr__(self):
- inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in self.input_types.items()])
+ inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
  return f"<Function {self.key}({inputs_str})>"
 
 
+ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
+ uses_non_warp_array_type = False
+
+ # Retrieve the built-in function from Warp's dll.
+ c_func = getattr(warp.context.runtime.core, func.mangled_name)
+
+ # Try gathering the parameters that the function expects and pack them
+ # into their corresponding C types.
+ c_params = []
+ for i, (_, arg_type) in enumerate(func.input_types.items()):
+ param = params[i]
+
+ try:
+ iter(param)
+ except TypeError:
+ is_array = False
+ else:
+ is_array = True
+
+ if is_array:
+ if not issubclass(arg_type, ctypes.Array):
+ return (False, None)
+
+ # The argument expects a built-in Warp type like a vector or a matrix.
+
+ c_param = None
+
+ if isinstance(param, ctypes.Array):
+ # The given parameter is also a built-in Warp type, so we only need
+ # to make sure that it matches with the argument.
+ if not warp.types.types_equal(type(param), arg_type):
+ return (False, None)
+
+ if isinstance(param, arg_type):
+ c_param = param
+ else:
+ # Cast the value to its argument type to make sure that it
+ # can be assigned to the field of the `Param` struct.
+ # This could error otherwise when, for example, the field type
+ # is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
+ # even though both types are semantically identical.
+ c_param = arg_type(param)
+ else:
+ # Flatten the parameter values into a flat 1-D array.
+ arr = []
+ ndim = 1
+ stack = [(0, param)]
+ while stack:
+ depth, elem = stack.pop(0)
+ try:
+ # If `elem` is a sequence, then it should be possible
+ # to add its elements to the stack for later processing.
+ stack.extend((depth + 1, x) for x in elem)
+ except TypeError:
+ # Since `elem` doesn't seem to be a sequence,
+ # we must have a leaf value that we need to add to our
+ # resulting array.
+ arr.append(elem)
+ ndim = max(depth, ndim)
+
+ assert ndim > 0
+
+ # Ensure that if the given parameter value is, say, a 2-D array,
+ # then we try to resolve it against a matrix argument rather than
+ # a vector.
+ if ndim > len(arg_type._shape_):
+ return (False, None)
+
+ elem_count = len(arr)
+ if elem_count != arg_type._length_:
+ return (False, None)
+
+ # Retrieve the element type of the sequence while ensuring
+ # that it's homogeneous.
+ elem_type = type(arr[0])
+ for i in range(1, elem_count):
+ if type(arr[i]) is not elem_type:
+ raise ValueError("All array elements must share the same type.")
+
+ expected_elem_type = arg_type._wp_scalar_type_
+ if not (
+ elem_type is expected_elem_type
+ or (elem_type is float and expected_elem_type is warp.types.float32)
+ or (elem_type is int and expected_elem_type is warp.types.int32)
+ or (
+ issubclass(elem_type, np.number)
+ and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
+ )
+ ):
+ # The parameter value has a type not matching the type defined
+ # for the corresponding argument.
+ return (False, None)
+
+ if elem_type in warp.types.int_types:
+ # Pass the value through the expected integer type
+ # in order to evaluate any integer wrapping.
+ # For example `uint8(-1)` should result in the value `-255`.
+ arr = tuple(elem_type._type_(x.value).value for x in arr)
+ elif elem_type in warp.types.float_types:
+ # Extract the floating-point values.
+ arr = tuple(x.value for x in arr)
+
+ c_param = arg_type()
+ if warp.types.type_is_matrix(arg_type):
+ rows, cols = arg_type._shape_
+ for i in range(rows):
+ idx_start = i * cols
+ idx_end = idx_start + cols
+ c_param[i] = arr[idx_start:idx_end]
+ else:
+ c_param[:] = arr
+
+ uses_non_warp_array_type = True
+
+ c_params.append(ctypes.byref(c_param))
+ else:
+ if issubclass(arg_type, ctypes.Array):
+ return (False, None)
+
+ if not (
+ isinstance(param, arg_type)
+ or (type(param) is float and arg_type is warp.types.float32)
+ or (type(param) is int and arg_type is warp.types.int32)
+ or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
+ ):
+ return (False, None)
+
+ if type(param) in warp.types.scalar_types:
+ param = param.value
+
+ # try to pack as a scalar type
+ if arg_type == warp.types.float16:
+ c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
+ else:
+ c_params.append(arg_type._type_(param))
+
+ # returns the corresponding ctype for a scalar or vector warp type
+ value_type = func.value_func(None, None, None)
+ if value_type == float:
+ value_ctype = ctypes.c_float
+ elif value_type == int:
+ value_ctype = ctypes.c_int32
+ elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
+ value_ctype = value_type
+ else:
+ # scalar type
+ value_ctype = value_type._type_
+
+ # construct return value (passed by address)
+ ret = value_ctype()
+ ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
+ c_params.append(ret_addr)
+
+ # Call the built-in function from Warp's dll.
+ c_func(*c_params)
+
+ if uses_non_warp_array_type:
+ warp.utils.warn(
+ "Support for built-in functions called with non-Warp array types, "
+ "such as lists, tuples, NumPy arrays, and others, will be dropped "
+ "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
+ "`wp.quat`, or `wp.transform`.",
+ DeprecationWarning,
+ stacklevel=3,
+ )
+
+ if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
+ # return vector types as ctypes
+ return (True, ret)
+
+ if value_type == warp.types.float16:
+ return (True, warp.types.half_bits_to_float(ret.value))
+
+ # return scalar types as int/float
+ return (True, ret.value)
+
+
  class KernelHooks:
  def __init__(self, forward, backward):
  self.forward = forward
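
As a quick illustration of the new resolution path: exported builtins remain directly callable from the Python interpreter, with call_builtin() now doing the per-overload argument packing. A minimal sketch, assuming wp.init() has been called and that length and dot are exported builtins (as in current Warp releases); printed values are approximate:

    import warp as wp

    wp.init()

    v = wp.vec3(1.0, 2.0, 3.0)
    print(wp.length(v))   # ~3.742 (sqrt(14))
    print(wp.dot(v, v))   # 14.0

    # Plain sequences are still accepted for vector/matrix arguments, but per the
    # code above they now trigger a DeprecationWarning asking for Warp types
    # such as wp.vec3 instead.
    print(wp.length((1.0, 2.0, 3.0)))
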
@@ -421,10 +525,20 @@ class KernelHooks:
 
  # caches source and compiled entry points for a kernel (will be populated after module loads)
  class Kernel:
- def __init__(self, func, key, module, options=None, code_transformers=[]):
+ def __init__(self, func, key=None, module=None, options=None, code_transformers=[]):
  self.func = func
- self.module = module
- self.key = key
+
+ if module is None:
+ self.module = get_module(func.__module__)
+ else:
+ self.module = module
+
+ if key is None:
+ unique_key = self.module.generate_unique_kernel_key(func.__name__)
+ self.key = unique_key
+ else:
+ self.key = key
+
  self.options = {} if options is None else options
 
  self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
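
With key and module now optional, a Kernel can be built straight from an annotated Python function: the module is inferred from func.__module__ and a unique key is generated automatically. A sketch of what this allows (illustrative only; the supported path remains the @wp.kernel decorator):

    import warp as wp

    wp.init()

    def scale(a: wp.array(dtype=float), s: float):
        tid = wp.tid()
        a[tid] = a[tid] * s

    k = wp.context.Kernel(scale)  # key and module arguments omitted

    a = wp.zeros(8, dtype=float)
    wp.launch(k, dim=8, inputs=[a, 2.0])
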
@@ -445,8 +559,8 @@ class Kernel:
  # argument indices by name
  self.arg_indices = dict((a.label, i) for i, a in enumerate(self.adj.args))
 
- if module:
- module.register_kernel(self)
+ if self.module:
+ self.module.register_kernel(self)
 
  def infer_argument_types(self, args):
  template_types = list(self.adj.arg_types.values())
@@ -523,7 +637,7 @@ def func(f):
  name = warp.codegen.make_full_qualified_name(f)
 
  m = get_module(f.__module__)
- func = Function(
+ Function(
  func=f, key=name, namespace="", module=m, value_func=None
  ) # value_type not known yet, will be inferred during Adjoint.build()
 
@@ -531,6 +645,167 @@ def func(f):
  return m.functions[name]
 
 
+ def func_native(snippet, adj_snippet=None):
+ """
+ Decorator to register native code snippet, @func_native
+ """
+
+ def snippet_func(f):
+ name = warp.codegen.make_full_qualified_name(f)
+
+ m = get_module(f.__module__)
+ func = Function(
+ func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet
+ ) # cuda snippets do not have a return value_type
+
+ return m.functions[name]
+
+ return snippet_func
+
+
+ def func_grad(forward_fn):
+ """
+ Decorator to register a custom gradient function for a given forward function.
+ The function signature must correspond to one of the function overloads in the following way:
+ the first part of the input arguments are the original input variables with the same types as their
+ corresponding arguments in the original function, and the second part of the input arguments are the
+ adjoint variables of the output variables (if available) of the original function with the same types as the
+ output variables. The function must not return anything.
+ """
+
+ def wrapper(grad_fn):
+ generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
+ if generic:
+ raise RuntimeError(
+ f"Cannot define custom grad definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
+ )
+
+ reverse_args = {}
+ reverse_args.update(forward_fn.input_types)
+
+ # create temporary Adjoint instance to analyze the function signature
+ adj = warp.codegen.Adjoint(
+ grad_fn, skip_forward_codegen=True, skip_reverse_codegen=False, transformers=forward_fn.adj.transformers
+ )
+
+ from warp.types import types_equal
+
+ grad_args = adj.args
+ grad_sig = warp.types.get_signature([arg.type for arg in grad_args], func_name=forward_fn.key)
+
+ generic = any(warp.types.type_is_generic(x.type) for x in grad_args)
+ if generic:
+ raise RuntimeError(
+ f"Cannot define custom grad definition for {forward_fn.key} since the provided grad function has generic input arguments."
+ )
+
+ def match_function(f):
+ # check whether the function overload f matches the signature of the provided gradient function
+ if not hasattr(f.adj, "return_var"):
+ f.adj.build(None)
+ expected_args = list(f.input_types.items())
+ if f.adj.return_var is not None:
+ expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
+ if len(grad_args) != len(expected_args):
+ return False
+ if any(not types_equal(a.type, exp_type) for a, (_, exp_type) in zip(grad_args, expected_args)):
+ return False
+ return True
+
+ def add_custom_grad(f: Function):
+ # register custom gradient function
+ f.custom_grad_func = Function(
+ grad_fn,
+ key=f.key,
+ namespace=f.namespace,
+ input_types=reverse_args,
+ value_func=None,
+ module=f.module,
+ template_func=f.template_func,
+ skip_forward_codegen=True,
+ custom_reverse_mode=True,
+ custom_reverse_num_input_args=len(f.input_types),
+ skip_adding_overload=False,
+ code_transformers=f.adj.transformers,
+ )
+ f.adj.skip_reverse_codegen = True
+
+ if hasattr(forward_fn, "user_overloads") and len(forward_fn.user_overloads):
+ # find matching overload for which this grad function is defined
+ for sig, f in forward_fn.user_overloads.items():
+ if not grad_sig.startswith(sig):
+ continue
+ if match_function(f):
+ add_custom_grad(f)
+ return
+ raise RuntimeError(
+ f"No function overload found for gradient function {grad_fn.__qualname__} for function {forward_fn.key}"
+ )
+ else:
+ # resolve return variables
+ forward_fn.adj.build(None)
+
+ expected_args = list(forward_fn.input_types.items())
+ if forward_fn.adj.return_var is not None:
+ expected_args += [(f"adj_ret_{var.label}", var.type) for var in forward_fn.adj.return_var]
+
+ # check if the signature matches this function
+ if match_function(forward_fn):
+ add_custom_grad(forward_fn)
+ else:
+ raise RuntimeError(
+ f"Gradient function {grad_fn.__qualname__} for function {forward_fn.key} has an incorrect signature. The arguments must match the "
+ "forward function arguments plus the adjoint variables corresponding to the return variables:"
+ f"\n{', '.join(map(lambda nt: f'{nt[0]}: {nt[1].__name__}', expected_args))}"
+ )
+
+ return wrapper
+
+
+ def func_replay(forward_fn):
+ """
+ Decorator to register a custom replay function for a given forward function.
+ The replay function is the function version that is called in the forward phase of the backward pass (replay mode) and corresponds to the forward function by default.
+ The provided function has to match the signature of one of the original forward function overloads.
+ """
+
+ def wrapper(replay_fn):
+ generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
+ if generic:
+ raise RuntimeError(
+ f"Cannot define custom replay definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
+ )
+
+ args = get_function_args(replay_fn)
+ arg_types = list(args.values())
+ generic = any(warp.types.type_is_generic(x) for x in arg_types)
+ if generic:
+ raise RuntimeError(
+ f"Cannot define custom replay definition for {forward_fn.key} since the provided replay function has generic input arguments."
+ )
+
+ f = forward_fn.get_overload(arg_types)
+ if f is None:
+ inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in args.items()])
+ raise RuntimeError(
+ f"Could not find forward definition of function {forward_fn.key} that matches custom replay definition with arguments:\n{inputs_str}"
+ )
+ f.custom_replay_func = Function(
+ replay_fn,
+ key=f"replay_{f.key}",
+ namespace=f.namespace,
+ input_types=f.input_types,
+ value_func=f.value_func,
+ module=f.module,
+ template_func=f.template_func,
+ skip_reverse_codegen=True,
+ skip_adding_overload=True,
+ code_transformers=f.adj.transformers,
+ )
+
+ return wrapper
+
+
  # decorator to register kernel, @kernel, custom_name may be a string
  # that creates a kernel with a different name from the actual function
  def kernel(f=None, *, enable_backward=None):
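
These decorators back the new user-facing hooks for differentiable and native code. Assuming they are re-exported at package level as wp.func_grad, wp.func_replay, and wp.func_native, a minimal custom-gradient and replay sketch, using the wp.adjoint accumulation idiom from the Warp documentation:

    import warp as wp

    wp.init()

    @wp.func
    def safe_sqrt(x: float) -> float:
        return wp.sqrt(x)

    # Custom gradient: arguments are the forward inputs followed by the adjoints
    # of the forward outputs; nothing is returned, gradients are accumulated
    # into wp.adjoint[...] instead.
    @wp.func_grad(safe_sqrt)
    def adj_safe_sqrt(x: float, adj_ret: float):
        if x > 0.0:
            wp.adjoint[x] += adj_ret / (2.0 * wp.sqrt(x))

    # Custom replay: must match a forward overload's signature; it is substituted
    # for the forward function when the tape re-runs the forward pass.
    @wp.func_replay(safe_sqrt)
    def replay_safe_sqrt(x: float) -> float:
        return wp.sqrt(x)

For native snippets, the Python body is only a signature stub; the string is spliced into the generated source by codegen_snippet(), with arguments visible by name inside the snippet. A sketch following the snippet style from the Warp documentation (names here are illustrative):

    snippet = "out[tid] = a * x[tid] + y[tid];"

    @wp.func_native(snippet)
    def saxpy(a: float, x: wp.array(dtype=float), y: wp.array(dtype=float),
              out: wp.array(dtype=float), tid: int):
        ...

    @wp.kernel
    def run_saxpy(a: float, x: wp.array(dtype=float), y: wp.array(dtype=float),
                  out: wp.array(dtype=float)):
        tid = wp.tid()
        saxpy(a, x, y, out, tid)
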
@@ -658,6 +933,7 @@ def add_builtin(
  missing_grad=False,
  native_func=None,
  defaults=None,
+ require_original_output_arg=False,
  ):
  # wrap simple single-type functions with a value_func()
  if value_func is None:
@@ -670,7 +946,7 @@
  def initializer_list_func(args, templates):
  return False
 
- if defaults == None:
+ if defaults is None:
  defaults = {}
 
  # Add specialized versions of this builtin if it's generic by matching arguments against
@@ -751,8 +1027,8 @@
  # on the generated argument list and skip generation if it fails.
  # This also gives us the return type, which we keep for later:
  try:
- return_type = value_func([warp.codegen.Var("", t) for t in argtypes], {}, [])
- except Exception as e:
+ return_type = value_func(argtypes, {}, [])
+ except Exception:
  continue
 
  # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
@@ -782,6 +1058,7 @@
  hidden=True,
  skip_replay=skip_replay,
  missing_grad=missing_grad,
+ require_original_output_arg=require_original_output_arg,
  )
 
  func = Function(
@@ -802,6 +1079,7 @@
  generic=generic,
  native_func=native_func,
  defaults=defaults,
+ require_original_output_arg=require_original_output_arg,
  )
 
  if key in builtin_functions:
@@ -811,7 +1089,7 @@
  # export means the function will be added to the `warp` module namespace
  # so that users can call it directly from the Python interpreter
- if export == True:
+ if export:
  if hasattr(warp, key):
  # check that we haven't already created something at this location
  # if it's just an overload stub for auto-complete then overwrite it
  # if it's just an overload stub for auto-complete then overwrite it
@@ -878,6 +1156,8 @@ class ModuleBuilder:
  for func in module.functions.values():
  for f in func.user_overloads.values():
  self.build_function(f)
+ if f.custom_replay_func is not None:
+ self.build_function(f.custom_replay_func)
 
  # build all kernel entry points
  for kernel in module.kernels.values():
@@ -894,8 +1174,7 @@
  while stack:
  s = stack.pop()
 
- if not s in structs:
- structs.append(s)
+ structs.append(s)
 
  for var in s.vars.values():
  if isinstance(var.type, warp.codegen.Struct):
@@ -927,7 +1206,7 @@
  if not func.value_func:
 
  def wrap(adj):
- def value_type(args, kwds, templates):
+ def value_type(arg_types, kwds, templates):
  if adj.return_var is None or len(adj.return_var) == 0:
  return None
  if len(adj.return_var) == 1:
@@ -951,7 +1230,14 @@
 
  # code-gen all imported functions
  for func in self.functions.keys():
- source += warp.codegen.codegen_func(func.adj, name=func.key, device=device, options=self.options)
+ if func.native_snippet is None:
+ source += warp.codegen.codegen_func(
+ func.adj, c_func_name=func.native_func, device=device, options=self.options
+ )
+ else:
+ source += warp.codegen.codegen_snippet(
+ func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
+ )
 
  for kernel in self.module.kernels.values():
  # each kernel gets an entry point in the module
@@ -1031,6 +1317,10 @@ class Module:
 
  self.content_hash = None
 
+ # number of times module auto-generates kernel key for user
+ # used to ensure unique kernel keys
+ self.count = 0
+
  def register_struct(self, struct):
  self.structs[struct.key] = struct
 
@@ -1045,7 +1335,7 @@
  # for a reload of module on next launch
  self.unload()
 
- def register_function(self, func):
+ def register_function(self, func, skip_adding_overload=False):
  if func.key not in self.functions:
  self.functions[func.key] = func
  else:
@@ -1065,7 +1355,7 @@
  )
  if sig == sig_existing:
  self.functions[func.key] = func
- else:
+ elif not skip_adding_overload:
  func_existing.add_overload(func)
 
  self.find_references(func.adj)
@@ -1073,6 +1363,11 @@
  # for a reload of module on next launch
  self.unload()
 
+ def generate_unique_kernel_key(self, key):
+ unique_key = f"{key}_{self.count}"
+ self.count += 1
+ return unique_key
+
  # collect all referenced functions / structs
  # given the AST of a function or kernel
  def find_references(self, adj):
@@ -1086,13 +1381,13 @@
  if isinstance(node, ast.Call):
  try:
  # try to resolve the function
- func, _ = adj.resolve_path(node.func)
+ func, _ = adj.resolve_static_expression(node.func, eval_types=False)
 
  # if this is a user-defined function, add a module reference
  if isinstance(func, warp.context.Function) and func.module is not None:
  add_ref(func.module)
 
- except:
+ except Exception:
  # Lookups may fail for builtins, but that's ok.
  # Lookups may also fail for functions in this module that haven't been imported yet,
  # and that's ok too (not an external reference).
@@ -1139,9 +1434,24 @@
  s = func.adj.source
  ch.update(bytes(s, "utf-8"))
 
+ if func.custom_grad_func:
+ s = func.custom_grad_func.adj.source
+ ch.update(bytes(s, "utf-8"))
+ if func.custom_replay_func:
+ s = func.custom_replay_func.adj.source
+
+ # cache func arg types
+ for arg, arg_type in func.adj.arg_types.items():
+ s = f"{arg}: {get_type_name(arg_type)}"
+ ch.update(bytes(s, "utf-8"))
+
  # kernel source
  for kernel in module.kernels.values():
  ch.update(bytes(kernel.adj.source, "utf-8"))
+ # cache kernel arg types
+ for arg, arg_type in kernel.adj.arg_types.items():
+ s = f"{arg}: {get_type_name(arg_type)}"
+ ch.update(bytes(s, "utf-8"))
  # for generic kernels the Python source is always the same,
  # but we hash the type signatures of all the overloads
  if kernel.is_generic:
@@ -1440,13 +1750,13 @@ class ContextGuard:
  def __enter__(self):
  if self.device.is_cuda:
  runtime.core.cuda_context_push_current(self.device.context)
- elif is_cuda_available():
+ elif is_cuda_driver_initialized():
  self.saved_context = runtime.core.cuda_context_get_current()
 
  def __exit__(self, exc_type, exc_value, traceback):
  if self.device.is_cuda:
  runtime.core.cuda_context_pop_current()
- elif is_cuda_available():
+ elif is_cuda_driver_initialized():
  runtime.core.cuda_context_set_current(self.saved_context)
 
 
@@ -1537,6 +1847,29 @@
 
 
  class Device:
+ """A device to allocate Warp arrays and to launch kernels on.
+
+ Attributes:
+ ordinal: A Warp-specific integer label for the device. ``-1`` for CPU devices.
+ name: A string label for the device. By default, CPU devices will be named according to the processor name,
+ or ``"CPU"`` if the processor name cannot be determined.
+ arch: An integer representing the compute capability version number calculated as
+ ``10 * major + minor``. ``0`` for CPU devices.
+ is_uva: A boolean indicating whether or not the device supports unified addressing.
+ ``False`` for CPU devices.
+ is_cubin_supported: A boolean indicating whether or not Warp's version of NVRTC can directly
+ generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
+ is_mempool_supported: A boolean indicating whether or not the device supports using the
+ ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
+ CPU devices.
+ is_primary: A boolean indicating whether or not this device's CUDA context is also the
+ device's primary context.
+ uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
+ ``nvidia-smi -L``. ``None`` for CPU devices.
+ pci_bus_id: A string identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
+ ``domain``, ``bus``, and ``device`` are all hexadecimal values. ``None`` for CPU devices.
+ """
+
  def __init__(self, runtime, alias, ordinal=-1, is_primary=False, context=None):
  self.runtime = runtime
  self.alias = alias
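
The newly documented Device attributes can be inspected directly; a small sketch, assuming a CUDA-capable machine (on CPU devices uuid and pci_bus_id are None and arch is 0):

    import warp as wp

    wp.init()

    d = wp.get_device("cuda:0")
    print(d.arch)                  # e.g. 86 for compute capability 8.6 (10 * major + minor)
    print(d.uuid)                  # "GPU-..." in the same format as `nvidia-smi -L`
    print(d.pci_bus_id)            # "[domain]:[bus]:[device]" as hexadecimal values
    print(d.is_mempool_supported)  # stream-ordered memory allocator support
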
@@ -1566,6 +1899,9 @@ class Device:
1566
1899
  self.arch = 0
1567
1900
  self.is_uva = False
1568
1901
  self.is_cubin_supported = False
1902
+ self.is_mempool_supported = False
1903
+ self.uuid = None
1904
+ self.pci_bus_id = None
1569
1905
 
1570
1906
  # TODO: add more device-specific dispatch functions
1571
1907
  self.memset = runtime.core.memset_host
@@ -1578,6 +1914,26 @@ class Device:
1578
1914
  self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
1579
1915
  # check whether our NVRTC can generate CUBINs for this architecture
1580
1916
  self.is_cubin_supported = self.arch in runtime.nvrtc_supported_archs
1917
+ self.is_mempool_supported = runtime.core.cuda_device_is_memory_pool_supported(ordinal)
1918
+
1919
+ uuid_buffer = (ctypes.c_char * 16)()
1920
+ runtime.core.cuda_device_get_uuid(ordinal, uuid_buffer)
1921
+ uuid_byte_str = bytes(uuid_buffer).hex()
1922
+ self.uuid = f"GPU-{uuid_byte_str[0:8]}-{uuid_byte_str[8:12]}-{uuid_byte_str[12:16]}-{uuid_byte_str[16:20]}-{uuid_byte_str[20:]}"
1923
+
1924
+ pci_domain_id = runtime.core.cuda_device_get_pci_domain_id(ordinal)
1925
+ pci_bus_id = runtime.core.cuda_device_get_pci_bus_id(ordinal)
1926
+ pci_device_id = runtime.core.cuda_device_get_pci_device_id(ordinal)
1927
+ # This is (mis)named to correspond to the naming of cudaDeviceGetPCIBusId
1928
+ self.pci_bus_id = f"{pci_domain_id:08X}:{pci_bus_id:02X}:{pci_device_id:02X}"
1929
+
1930
+ # Warn the user of a possible misconfiguration of their system
1931
+ if not self.is_mempool_supported:
1932
+ warp.utils.warn(
1933
+ f"Support for stream ordered memory allocators was not detected on device {ordinal}. "
1934
+ "This can prevent the use of graphs and/or result in poor performance. "
1935
+ "Is the UVM driver enabled?"
1936
+ )
1581
1937
 
1582
1938
  # initialize streams unless context acquisition is postponed
1583
1939
  if self._context is not None:
@@ -1601,14 +1957,17 @@ class Device:
1601
1957
 
1602
1958
  @property
1603
1959
  def is_cpu(self):
1960
+ """A boolean indicating whether or not the device is a CPU device."""
1604
1961
  return self.ordinal < 0
1605
1962
 
1606
1963
  @property
1607
1964
  def is_cuda(self):
1965
+ """A boolean indicating whether or not the device is a CUDA device."""
1608
1966
  return self.ordinal >= 0
1609
1967
 
1610
1968
  @property
1611
1969
  def context(self):
1970
+ """The context associated with the device."""
1612
1971
  if self._context is not None:
1613
1972
  return self._context
1614
1973
  elif self.is_primary:
@@ -1623,10 +1982,16 @@ class Device:
1623
1982
 
1624
1983
  @property
1625
1984
  def has_context(self):
1985
+ """A boolean indicating whether or not the device has a CUDA context associated with it."""
1626
1986
  return self._context is not None
1627
1987
 
1628
1988
  @property
1629
1989
  def stream(self):
1990
+ """The stream associated with a CUDA device.
1991
+
1992
+ Raises:
1993
+ RuntimeError: The device is not a CUDA device.
1994
+ """
1630
1995
  if self.context:
1631
1996
  return self._stream
1632
1997
  else:
@@ -1644,6 +2009,7 @@ class Device:
1644
2009
 
1645
2010
  @property
1646
2011
  def has_stream(self):
2012
+ """A boolean indicating whether or not the device has a stream associated with it."""
1647
2013
  return self._stream is not None
1648
2014
 
1649
2015
  def __str__(self):
@@ -1721,7 +2087,7 @@ class Runtime:
1721
2087
 
1722
2088
  self.core = self.load_dll(warp_lib)
1723
2089
 
1724
- if llvm_lib and os.path.exists(llvm_lib):
2090
+ if os.path.exists(llvm_lib):
1725
2091
  self.llvm = self.load_dll(llvm_lib)
1726
2092
  # setup c-types for warp-clang.dll
1727
2093
  self.llvm.lookup.restype = ctypes.c_uint64
@@ -2087,6 +2453,8 @@ class Runtime:
2087
2453
  self.core.cuda_driver_version.restype = ctypes.c_int
2088
2454
  self.core.cuda_toolkit_version.argtypes = None
2089
2455
  self.core.cuda_toolkit_version.restype = ctypes.c_int
2456
+ self.core.cuda_driver_is_initialized.argtypes = None
2457
+ self.core.cuda_driver_is_initialized.restype = ctypes.c_bool
2090
2458
 
2091
2459
  self.core.nvrtc_supported_arch_count.argtypes = None
2092
2460
  self.core.nvrtc_supported_arch_count.restype = ctypes.c_int
@@ -2103,6 +2471,14 @@ class Runtime:
2103
2471
  self.core.cuda_device_get_arch.restype = ctypes.c_int
2104
2472
  self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
2105
2473
  self.core.cuda_device_is_uva.restype = ctypes.c_int
2474
+ self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
2475
+ self.core.cuda_device_get_uuid.restype = None
2476
+ self.core.cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
2477
+ self.core.cuda_device_get_pci_domain_id.restype = ctypes.c_int
2478
+ self.core.cuda_device_get_pci_bus_id.argtypes = [ctypes.c_int]
2479
+ self.core.cuda_device_get_pci_bus_id.restype = ctypes.c_int
2480
+ self.core.cuda_device_get_pci_device_id.argtypes = [ctypes.c_int]
2481
+ self.core.cuda_device_get_pci_device_id.restype = ctypes.c_int
2106
2482
 
2107
2483
  self.core.cuda_context_get_current.argtypes = None
2108
2484
  self.core.cuda_context_get_current.restype = ctypes.c_void_p
@@ -2189,6 +2565,7 @@ class Runtime:
2189
2565
  ctypes.c_void_p,
2190
2566
  ctypes.c_void_p,
2191
2567
  ctypes.c_size_t,
2568
+ ctypes.c_int,
2192
2569
  ctypes.POINTER(ctypes.c_void_p),
2193
2570
  ]
2194
2571
  self.core.cuda_launch_kernel.restype = ctypes.c_size_t
@@ -2309,8 +2686,15 @@ class Runtime:
2309
2686
  dll = ctypes.CDLL(dll_path, winmode=0)
2310
2687
  else:
2311
2688
  dll = ctypes.CDLL(dll_path)
2312
- except OSError:
2313
- raise RuntimeError(f"Failed to load the shared library '{dll_path}'")
2689
+ except OSError as e:
2690
+ if "GLIBCXX" in str(e):
2691
+ raise RuntimeError(
2692
+ f"Failed to load the shared library '{dll_path}'.\n"
2693
+ "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n"
2694
+ "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details."
2695
+ ) from e
2696
+ else:
2697
+ raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e
2314
2698
  return dll
2315
2699
 
2316
2700
  def get_device(self, ident: Devicelike = None) -> Device:
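The GLIBCXX branch above only changes the message attached to the RuntimeError, so caller-side handling stays the same; a hedged sketch of surfacing it at initialization time:

    import warp as wp

    try:
        wp.init()
    except RuntimeError as e:
        # on conda environments with an outdated libstdc++ the message now
        # points at the installation docs instead of a bare load failure
        print(e)
        raise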
@@ -2439,6 +2823,21 @@ def is_device_available(device):
2439
2823
  return device in get_devices()
2440
2824
 
2441
2825
 
2826
+ def is_cuda_driver_initialized() -> bool:
2827
+ """Returns ``True`` if the CUDA driver is initialized.
2828
+
2829
+ This is a stricter test than ``is_cuda_available()`` since a CUDA driver
2830
+ call to ``cuCtxGetCurrent`` is made, and the result is compared to
2831
+ ``CUDA_SUCCESS``. Note that ``CUDA_SUCCESS`` is returned by ``cuCtxGetCurrent``
2832
+ even if there is no context bound to the calling CPU thread.
2833
+
2834
+ This can be helpful in cases in which ``cuInit()`` was called before a fork.
2835
+ """
2836
+ assert_initialized()
2837
+
2838
+ return runtime.core.cuda_driver_is_initialized()
2839
+
2840
+
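A minimal sketch contrasting the new check with ``is_cuda_available()``, e.g. in a process that inherited a CUDA context across a fork (assuming the helper is exported at the package level; otherwise it is reachable as warp.context.is_cuda_driver_initialized):

    import warp as wp

    wp.init()
    if wp.is_cuda_available() and not wp.is_cuda_driver_initialized():
        # cuInit() likely happened in the parent before fork(); avoid CUDA work here
        print("CUDA devices enumerated, but the driver is not usable in this process")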
2442
2841
  def get_devices() -> List[Device]:
2443
2842
  """Returns a list of devices supported in this environment."""
2444
2843
 
@@ -2749,7 +3148,7 @@ def full(
2749
3148
  elif na.ndim == 2:
2750
3149
  dtype = warp.types.matrix(na.shape, scalar_type)
2751
3150
  else:
2752
- raise ValueError(f"Values with more than two dimensions are not supported")
3151
+ raise ValueError("Values with more than two dimensions are not supported")
2753
3152
  else:
2754
3153
  raise ValueError(f"Invalid value type for Warp array: {value_type}")
2755
3154
 
@@ -2872,8 +3271,34 @@ def empty_like(
2872
3271
  return arr
2873
3272
 
2874
3273
 
2875
- def from_numpy(arr, dtype, device: Devicelike = None, requires_grad=False):
2876
- return warp.array(data=arr, dtype=dtype, device=device, requires_grad=requires_grad)
3274
+ def from_numpy(
3275
+ arr: np.ndarray,
3276
+ dtype: Optional[type] = None,
3277
+ shape: Optional[Sequence[int]] = None,
3278
+ device: Optional[Devicelike] = None,
3279
+ requires_grad: bool = False,
3280
+ ) -> warp.array:
3281
+ if dtype is None:
3282
+ base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
3283
+ if base_type is None:
3284
+ raise RuntimeError("Unsupported NumPy data type '{}'.".format(arr.dtype))
3285
+
3286
+ dim_count = len(arr.shape)
3287
+ if dim_count == 2:
3288
+ dtype = warp.types.vector(length=arr.shape[1], dtype=base_type)
3289
+ elif dim_count == 3:
3290
+ dtype = warp.types.matrix(shape=(arr.shape[1], arr.shape[2]), dtype=base_type)
3291
+ else:
3292
+ dtype = base_type
3293
+
3294
+ return warp.array(
3295
+ data=arr,
3296
+ dtype=dtype,
3297
+ shape=shape,
3298
+ owner=False,
3299
+ device=device,
3300
+ requires_grad=requires_grad,
3301
+ )
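With the new signature, the dtype can be inferred from the NumPy array's shape: a 2-D array maps to vectors and a 3-D array to matrices. A minimal sketch (assuming the helper is exported as wp.from_numpy):

    import numpy as np
    import warp as wp

    wp.init()
    points = np.random.rand(1024, 3).astype(np.float32)
    a = wp.from_numpy(points)    # dtype inferred as a length-3 float32 vector
    print(a.dtype, a.shape)      # -> vec3f-like type, (1024,)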
2877
3302
 
2878
3303
 
2879
3304
  # given a kernel destination argument type and a value convert
@@ -2889,9 +3314,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
2889
3314
  # - in forward passes, array types have to match
2890
3315
  # - in backward passes, indexed array gradients are regular arrays
2891
3316
  if adjoint:
2892
- array_matches = type(value) == warp.array
3317
+ array_matches = isinstance(value, warp.array)
2893
3318
  else:
2894
- array_matches = type(value) == type(arg_type)
3319
+ array_matches = type(value) is type(arg_type)
2895
3320
 
2896
3321
  if not array_matches:
2897
3322
  adj = "adjoint " if adjoint else ""
@@ -2934,7 +3359,7 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
2934
3359
  # try constructing the required value from the argument (handles tuple / list, Gf.Vec3 case)
2935
3360
  try:
2936
3361
  return arg_type(value)
2937
- except:
3362
+ except Exception:
2938
3363
  raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}")
2939
3364
 
2940
3365
  elif isinstance(value, bool):
@@ -2943,27 +3368,35 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
2943
3368
  elif isinstance(value, arg_type):
2944
3369
  try:
2945
3370
  # try to pack as a scalar type
2946
- return arg_type._type_(value.value)
2947
- except:
3371
+ if arg_type is warp.types.float16:
3372
+ return arg_type._type_(warp.types.float_to_half_bits(value.value))
3373
+ else:
3374
+ return arg_type._type_(value.value)
3375
+ except Exception:
2948
3376
  raise RuntimeError(
2949
- f"Error launching kernel, unable to pack kernel parameter type {type(value)} for param {arg_name}, expected {arg_type}"
3377
+ "Error launching kernel, unable to pack kernel parameter type "
3378
+ f"{type(value)} for param {arg_name}, expected {arg_type}"
2950
3379
  )
2951
3380
 
2952
3381
  else:
2953
3382
  try:
2954
3383
  # try to pack as a scalar type
2955
- return arg_type._type_(value)
3384
+ if arg_type is warp.types.float16:
3385
+ return arg_type._type_(warp.types.float_to_half_bits(value))
3386
+ else:
3387
+ return arg_type._type_(value)
2956
3388
  except Exception as e:
2957
3389
  print(e)
2958
3390
  raise RuntimeError(
2959
- f"Error launching kernel, unable to pack kernel parameter type {type(value)} for param {arg_name}, expected {arg_type}"
3391
+ "Error launching kernel, unable to pack kernel parameter type "
3392
+ f"{type(value)} for param {arg_name}, expected {arg_type}"
2960
3393
  )
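With the float16 branches above, a plain Python float passed for a wp.float16 parameter is converted to half bits before the launch. A small sketch (hypothetical kernel, assuming float16 support on the target device):

    import warp as wp

    wp.init()

    @wp.kernel
    def fill_half(a: wp.array(dtype=wp.float16), s: wp.float16):
        i = wp.tid()
        a[i] = s

    a = wp.zeros(8, dtype=wp.float16)
    wp.launch(fill_half, dim=8, inputs=[a, 0.5])   # 0.5 is packed via float_to_half_bits()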
2961
3394
 
2962
3395
 
2963
3396
  # represents all data required for a kernel launch
2964
3397
  # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
2965
3398
  class Launch:
2966
- def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None):
3399
+ def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
2967
3400
  # if not specified look up hooks
2968
3401
  if not hooks:
2969
3402
  module = kernel.module
@@ -3000,6 +3433,7 @@ class Launch:
3000
3433
  self.params_addr = params_addr
3001
3434
  self.device = device
3002
3435
  self.bounds = bounds
3436
+ self.max_blocks = max_blocks
3003
3437
 
3004
3438
  def set_dim(self, dim):
3005
3439
  self.bounds = warp.types.launch_bounds_t(dim)
@@ -3065,7 +3499,9 @@ class Launch:
3065
3499
  if self.device.is_cpu:
3066
3500
  self.hooks.forward(*self.params)
3067
3501
  else:
3068
- runtime.core.cuda_launch_kernel(self.device.context, self.hooks.forward, self.bounds.size, self.params_addr)
3502
+ runtime.core.cuda_launch_kernel(
3503
+ self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
3504
+ )
3069
3505
 
3070
3506
 
3071
3507
  def launch(
@@ -3080,6 +3516,7 @@ def launch(
3080
3516
  adjoint=False,
3081
3517
  record_tape=True,
3082
3518
  record_cmd=False,
3519
+ max_blocks=0,
3083
3520
  ):
3084
3521
  """Launch a Warp kernel on the target device
3085
3522
 
@@ -3097,6 +3534,8 @@ def launch(
3097
3534
  adjoint: Whether to run forward or backward pass (typically use False)
3098
3535
  record_tape: When True the launch will be recorded onto the global wp.Tape() object when present
3099
3536
  record_cmd: When True the launch will be returned as a ``Launch`` command object; the launch will not occur until the user calls ``cmd.launch()``
3537
+ max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
3538
+ If negative or zero, the maximum hardware value will be used.
3100
3539
  """
3101
3540
 
3102
3541
  assert_initialized()
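A brief sketch of the new max_blocks argument, together with the Launch object returned by record_cmd=True (hypothetical kernel and sizes, assuming a CUDA device):

    import warp as wp

    wp.init()

    @wp.kernel
    def inc(a: wp.array(dtype=float)):
        i = wp.tid()
        a[i] = a[i] + 1.0

    a = wp.zeros(1 << 20, dtype=float, device="cuda:0")

    # limit the launch to at most 256 thread blocks (0 means "use the hardware maximum")
    wp.launch(inc, dim=a.shape[0], inputs=[a], max_blocks=256)

    # record once, then replay without re-packing arguments
    cmd = wp.launch(inc, dim=a.shape[0], inputs=[a], record_cmd=True)
    cmd.launch()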
@@ -3108,7 +3547,7 @@ def launch(
3108
3547
  device = runtime.get_device(device)
3109
3548
 
3110
3549
  # check function is a Kernel
3111
- if isinstance(kernel, Kernel) == False:
3550
+ if not isinstance(kernel, Kernel):
3112
3551
  raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
3113
3552
 
3114
3553
  # debugging aid
@@ -3190,7 +3629,9 @@ def launch(
3190
3629
  f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
3191
3630
  )
3192
3631
 
3193
- runtime.core.cuda_launch_kernel(device.context, hooks.backward, bounds.size, kernel_params)
3632
+ runtime.core.cuda_launch_kernel(
3633
+ device.context, hooks.backward, bounds.size, max_blocks, kernel_params
3634
+ )
3194
3635
 
3195
3636
  else:
3196
3637
  if hooks.forward is None:
@@ -3211,7 +3652,9 @@ def launch(
3211
3652
 
3212
3653
  else:
3213
3654
  # launch
3214
- runtime.core.cuda_launch_kernel(device.context, hooks.forward, bounds.size, kernel_params)
3655
+ runtime.core.cuda_launch_kernel(
3656
+ device.context, hooks.forward, bounds.size, max_blocks, kernel_params
3657
+ )
3215
3658
 
3216
3659
  try:
3217
3660
  runtime.verify_cuda_device(device)
@@ -3221,7 +3664,7 @@ def launch(
3221
3664
 
3222
3665
  # record on tape if one is active
3223
3666
  if runtime.tape and record_tape:
3224
- runtime.tape.record_launch(kernel, dim, inputs, outputs, device)
3667
+ runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device)
3225
3668
 
3226
3669
 
3227
3670
  def synchronize():
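record_launch() now stores max_blocks as well, so a tape replay uses the same block limit for the backward pass. The usual Tape pattern is unchanged; a minimal sketch:

    import warp as wp

    wp.init()

    @wp.kernel
    def square(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i = wp.tid()
        y[i] = x[i] * x[i]

    x = wp.array([3.0], dtype=float, requires_grad=True)
    y = wp.zeros(1, dtype=float, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(square, dim=1, inputs=[x], outputs=[y], max_blocks=0)

    tape.backward(grads={y: wp.array([1.0], dtype=float)})
    print(x.grad)   # gradient of x*x at x=3 is expected to be ~6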
@@ -3231,7 +3674,7 @@ def synchronize():
3231
3674
  or memory copies have completed.
3232
3675
  """
3233
3676
 
3234
- if is_cuda_available():
3677
+ if is_cuda_driver_initialized():
3235
3678
  # save the original context to avoid side effects
3236
3679
  saved_context = runtime.core.cuda_context_get_current()
3237
3680
 
@@ -3281,7 +3724,7 @@ def synchronize_stream(stream_or_device=None):
3281
3724
  runtime.core.cuda_stream_synchronize(stream.device.context, stream.cuda_stream)
3282
3725
 
3283
3726
 
3284
- def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
3727
+ def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
3285
3728
  """Force user-defined kernels to be compiled and loaded
3286
3729
 
3287
3730
  Args:
@@ -3289,12 +3732,14 @@ def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
3289
3732
  modules: List of modules to load. If None, load all imported modules.
3290
3733
  """
3291
3734
 
3292
- if is_cuda_available():
3735
+ if is_cuda_driver_initialized():
3293
3736
  # save original context to avoid side effects
3294
3737
  saved_context = runtime.core.cuda_context_get_current()
3295
3738
 
3296
3739
  if device is None:
3297
3740
  devices = get_devices()
3741
+ elif isinstance(device, list):
3742
+ devices = [get_device(device_item) for device_item in device]
3298
3743
  else:
3299
3744
  devices = [get_device(device)]
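force_load() now also accepts a list of devices, per the isinstance(device, list) branch above; a short sketch (hypothetical device names):

    import warp as wp

    wp.init()
    # compile and load all imported modules for a chosen set of devices up front
    wp.force_load(device=["cpu", "cuda:0"])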
3300
3745
 
@@ -3386,7 +3831,7 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
3386
3831
  return get_module(m.__name__).options
3387
3832
 
3388
3833
 
3389
- def capture_begin(device: Devicelike = None, stream=None, force_module_load=True):
3834
+ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
3390
3835
  """Begin capture of a CUDA graph
3391
3836
 
3392
3837
  Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -3400,7 +3845,10 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True
3400
3845
 
3401
3846
  """
3402
3847
 
3403
- if warp.config.verify_cuda == True:
3848
+ if force_module_load is None:
3849
+ force_module_load = warp.config.graph_capture_module_load_default
3850
+
3851
+ if warp.config.verify_cuda:
3404
3852
  raise RuntimeError("Cannot use CUDA error verification during graph capture")
3405
3853
 
3406
3854
  if stream is not None:
@@ -3415,6 +3863,9 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True
3415
3863
 
3416
3864
  device.is_capturing = True
3417
3865
 
3866
+ # disable garbage collection to avoid older allocations getting collected during graph capture
3867
+ gc.disable()
3868
+
3418
3869
  with warp.ScopedStream(stream):
3419
3870
  runtime.core.cuda_graph_begin_capture(device.context)
3420
3871
 
@@ -3438,6 +3889,9 @@ def capture_end(device: Devicelike = None, stream=None) -> Graph:
3438
3889
 
3439
3890
  device.is_capturing = False
3440
3891
 
3892
+ # re-enable GC
3893
+ gc.enable()
3894
+
3441
3895
  if graph is None:
3442
3896
  raise RuntimeError(
3443
3897
  "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
@@ -3557,6 +4011,16 @@ def copy(
3557
4011
  if src_elem_size != dst_elem_size:
3558
4012
  raise RuntimeError("Incompatible array data types")
3559
4013
 
4014
+ # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
4015
+ # TODO?
4016
+ if (
4017
+ isinstance(src, (warp.fabricarray, warp.indexedfabricarray))
4018
+ and src.ndim > 1
4019
+ or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
4020
+ and dest.ndim > 1
4021
+ ):
4022
+ raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")
4023
+
3560
4024
  src_desc = src.__ctype__()
3561
4025
  dst_desc = dest.__ctype__()
3562
4026
  src_ptr = ctypes.pointer(src_desc)
@@ -3592,6 +4056,10 @@ def type_str(t):
3592
4056
  return f"Array[{type_str(t.dtype)}]"
3593
4057
  elif isinstance(t, warp.indexedarray):
3594
4058
  return f"IndexedArray[{type_str(t.dtype)}]"
4059
+ elif isinstance(t, warp.fabricarray):
4060
+ return f"FabricArray[{type_str(t.dtype)}]"
4061
+ elif isinstance(t, warp.indexedfabricarray):
4062
+ return f"IndexedFabricArray[{type_str(t.dtype)}]"
3595
4063
  elif hasattr(t, "_wp_generic_type_str_"):
3596
4064
  generic_type = t._wp_generic_type_str_
3597
4065
 
@@ -3618,7 +4086,7 @@ def type_str(t):
3618
4086
  return t.__name__
3619
4087
 
3620
4088
 
3621
- def print_function(f, file, noentry=False):
4089
+ def print_function(f, file, noentry=False): # pragma: no cover
3622
4090
  """Writes a function definition to a file for use in reST documentation
3623
4091
 
3624
4092
  Args:
@@ -3642,7 +4110,7 @@ def print_function(f, file, noentry=False):
3642
4110
  # todo: construct a default value for each of the functions args
3643
4111
  # so we can generate the return type for overloaded functions
3644
4112
  return_type = " -> " + type_str(f.value_func(None, None, None))
3645
- except:
4113
+ except Exception:
3646
4114
  pass
3647
4115
 
3648
4116
  print(f".. function:: {f.key}({args}){return_type}", file=file)
@@ -3663,7 +4131,7 @@ def print_function(f, file, noentry=False):
3663
4131
  return True
3664
4132
 
3665
4133
 
3666
- def print_builtins(file):
4134
+ def export_functions_rst(file): # pragma: no cover
3667
4135
  header = (
3668
4136
  "..\n"
3669
4137
  " Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
@@ -3683,6 +4151,8 @@ def print_builtins(file):
3683
4151
 
3684
4152
  for t in warp.types.scalar_types:
3685
4153
  print(f".. class:: {t.__name__}", file=file)
4154
+ # Manually add wp.bool since it's inconvenient to add to wp.types.scalar_types:
4155
+ print(f".. class:: {warp.types.bool.__name__}", file=file)
3686
4156
 
3687
4157
  print("\n\nVector Types", file=file)
3688
4158
  print("------------", file=file)
@@ -3693,14 +4163,22 @@ def print_builtins(file):
3693
4163
  print("\nGeneric Types", file=file)
3694
4164
  print("-------------", file=file)
3695
4165
 
3696
- print(f".. class:: Int", file=file)
3697
- print(f".. class:: Float", file=file)
3698
- print(f".. class:: Scalar", file=file)
3699
- print(f".. class:: Vector", file=file)
3700
- print(f".. class:: Matrix", file=file)
3701
- print(f".. class:: Quaternion", file=file)
3702
- print(f".. class:: Transformation", file=file)
3703
- print(f".. class:: Array", file=file)
4166
+ print(".. class:: Int", file=file)
4167
+ print(".. class:: Float", file=file)
4168
+ print(".. class:: Scalar", file=file)
4169
+ print(".. class:: Vector", file=file)
4170
+ print(".. class:: Matrix", file=file)
4171
+ print(".. class:: Quaternion", file=file)
4172
+ print(".. class:: Transformation", file=file)
4173
+ print(".. class:: Array", file=file)
4174
+
4175
+ print("\nQuery Types", file=file)
4176
+ print("-------------", file=file)
4177
+ print(".. autoclass:: bvh_query_t", file=file)
4178
+ print(".. autoclass:: hash_grid_query_t", file=file)
4179
+ print(".. autoclass:: mesh_query_aabb_t", file=file)
4180
+ print(".. autoclass:: mesh_query_point_t", file=file)
4181
+ print(".. autoclass:: mesh_query_ray_t", file=file)
3704
4182
 
3705
4183
  # build dictionary of all functions by group
3706
4184
  groups = {}
@@ -3735,7 +4213,7 @@ def print_builtins(file):
3735
4213
  print(".. [1] Note: function gradients not implemented for backpropagation.", file=file)
3736
4214
 
3737
4215
 
3738
- def export_stubs(file):
4216
+ def export_stubs(file): # pragma: no cover
3739
4217
  """Generates stub file for auto-complete of builtin functions"""
3740
4218
 
3741
4219
  import textwrap
@@ -3767,6 +4245,8 @@ def export_stubs(file):
3767
4245
  print("Quaternion = Generic[Float]", file=file)
3768
4246
  print("Transformation = Generic[Float]", file=file)
3769
4247
  print("Array = Generic[DType]", file=file)
4248
+ print("FabricArray = Generic[DType]", file=file)
4249
+ print("IndexedFabricArray = Generic[DType]", file=file)
3770
4250
 
3771
4251
  # prepend __init__.py
3772
4252
  with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
@@ -3783,7 +4263,7 @@ def export_stubs(file):
3783
4263
 
3784
4264
  return_str = ""
3785
4265
 
3786
- if f.export == False or f.hidden == True: # or f.generic:
4266
+ if not f.export or f.hidden: # or f.generic:
3787
4267
  continue
3788
4268
 
3789
4269
  try:
@@ -3793,29 +4273,42 @@ def export_stubs(file):
3793
4273
  if return_type:
3794
4274
  return_str = " -> " + type_str(return_type)
3795
4275
 
3796
- except:
4276
+ except Exception:
3797
4277
  pass
3798
4278
 
3799
4279
  print("@over", file=file)
3800
4280
  print(f"def {f.key}({args}){return_str}:", file=file)
3801
- print(f' """', file=file)
4281
+ print(' """', file=file)
3802
4282
  print(textwrap.indent(text=f.doc, prefix=" "), file=file)
3803
- print(f' """', file=file)
3804
- print(f" ...\n\n", file=file)
4283
+ print(' """', file=file)
4284
+ print(" ...\n\n", file=file)
3805
4285
 
3806
4286
 
3807
- def export_builtins(file):
3808
- def ctype_str(t):
4287
+ def export_builtins(file: io.TextIOBase): # pragma: no cover
4288
+ def ctype_arg_str(t):
3809
4289
  if isinstance(t, int):
3810
4290
  return "int"
3811
4291
  elif isinstance(t, float):
3812
4292
  return "float"
4293
+ elif t in warp.types.vector_types:
4294
+ return f"{t.__name__}&"
3813
4295
  else:
3814
4296
  return t.__name__
3815
4297
 
4298
+ def ctype_ret_str(t):
4299
+ if isinstance(t, int):
4300
+ return "int"
4301
+ elif isinstance(t, float):
4302
+ return "float"
4303
+ else:
4304
+ return t.__name__
4305
+
4306
+ file.write("namespace wp {\n\n")
4307
+ file.write('extern "C" {\n\n')
4308
+
3816
4309
  for k, g in builtin_functions.items():
3817
4310
  for f in g.overloads:
3818
- if f.export == False or f.generic:
4311
+ if not f.export or f.generic:
3819
4312
  continue
3820
4313
 
3821
4314
  simple = True
@@ -3829,7 +4322,7 @@ def export_builtins(file):
3829
4322
  if not simple or f.variadic:
3830
4323
  continue
3831
4324
 
3832
- args = ", ".join(f"{ctype_str(v)} {k}" for k, v in f.input_types.items())
4325
+ args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
3833
4326
  params = ", ".join(f.input_types.keys())
3834
4327
 
3835
4328
  return_type = ""
@@ -3837,25 +4330,25 @@ def export_builtins(file):
3837
4330
  try:
3838
4331
  # todo: construct a default value for each of the functions args
3839
4332
  # so we can generate the return type for overloaded functions
3840
- return_type = ctype_str(f.value_func(None, None, None))
3841
- except:
4333
+ return_type = ctype_ret_str(f.value_func(None, None, None))
4334
+ except Exception:
3842
4335
  continue
3843
4336
 
3844
4337
  if return_type.startswith("Tuple"):
3845
4338
  continue
3846
4339
 
3847
4340
  if args == "":
3848
- print(
3849
- f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}", file=file
3850
- )
4341
+ file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
3851
4342
  elif return_type == "None":
3852
- print(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}", file=file)
4343
+ file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
3853
4344
  else:
3854
- print(
3855
- f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}",
3856
- file=file,
4345
+ file.write(
4346
+ f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
3857
4347
  )
3858
4348
 
4349
+ file.write('\n} // extern "C"\n\n')
4350
+ file.write("} // namespace wp\n")
4351
+
3859
4352
 
3860
4353
  # initialize global runtime
3861
4354
  runtime = None