warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  from __future__ import annotations
9
17
 
@@ -18,7 +26,7 @@ import re
18
26
  import sys
19
27
  import textwrap
20
28
  import types
21
- from typing import Any, Callable, Dict, Mapping, Optional, Sequence
29
+ from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin
22
30
 
23
31
  import warp.config
24
32
  from warp.types import *
@@ -49,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
49
57
 
50
58
 
51
59
  # map operator to function name
52
- builtin_operators = {}
60
+ builtin_operators: Dict[type[ast.AST], str] = {}
53
61
 
54
62
  # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
55
63
  # nice overview of python operators
@@ -114,16 +122,6 @@ def get_closure_cell_contents(obj):
114
122
  return None
115
123
 
116
124
 
117
- def get_type_origin(tp):
118
- # Compatible version of `typing.get_origin()` for Python 3.7 and older.
119
- return getattr(tp, "__origin__", None)
120
-
121
-
122
- def get_type_args(tp):
123
- # Compatible version of `typing.get_args()` for Python 3.7 and older.
124
- return getattr(tp, "__args__", ())
125
-
126
-
127
125
  def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
128
126
  """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
129
127
  # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
@@ -407,12 +405,14 @@ class StructInstance:
407
405
 
408
406
 
409
407
  class Struct:
410
- def __init__(self, cls, key, module):
408
+ hash: bytes
409
+
410
+ def __init__(self, cls: type, key: str, module: warp.context.Module):
411
411
  self.cls = cls
412
412
  self.module = module
413
413
  self.key = key
414
+ self.vars: Dict[str, Var] = {}
414
415
 
415
- self.vars = {}
416
416
  annotations = get_annotations(self.cls)
417
417
  for label, type in annotations.items():
418
418
  self.vars[label] = Var(label, type)
@@ -583,11 +583,11 @@ class Reference:
583
583
  self.value_type = value_type
584
584
 
585
585
 
586
- def is_reference(type):
586
+ def is_reference(type: Any) -> builtins.bool:
587
587
  return isinstance(type, Reference)
588
588
 
589
589
 
590
- def strip_reference(arg):
590
+ def strip_reference(arg: Any) -> Any:
591
591
  if is_reference(arg):
592
592
  return arg.value_type
593
593
  else:
@@ -615,7 +615,15 @@ def compute_type_str(base_name, template_params):
615
615
 
616
616
 
617
617
  class Var:
618
- def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
618
+ def __init__(
619
+ self,
620
+ label: str,
621
+ type: type,
622
+ requires_grad: builtins.bool = False,
623
+ constant: Optional[builtins.bool] = None,
624
+ prefix: builtins.bool = True,
625
+ relative_lineno: Optional[int] = None,
626
+ ):
619
627
  # convert built-in types to wp types
620
628
  if type == float:
621
629
  type = float32
@@ -638,11 +646,14 @@ class Var:
638
646
  # used to associate a view array Var with its parent array Var
639
647
  self.parent = None
640
648
 
649
+ # Used to associate the variable with the Python statement that resulted in it being created.
650
+ self.relative_lineno = relative_lineno
651
+
641
652
  def __str__(self):
642
653
  return self.label
643
654
 
644
655
  @staticmethod
645
- def type_to_ctype(t, value_type=False):
656
+ def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
646
657
  if is_array(t):
647
658
  if hasattr(t.dtype, "_wp_generic_type_str_"):
648
659
  dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
@@ -673,7 +684,7 @@ class Var:
673
684
  else:
674
685
  return f"wp::{t.__name__}"
675
686
 
676
- def ctype(self, value_type=False):
687
+ def ctype(self, value_type: builtins.bool = False) -> str:
677
688
  return Var.type_to_ctype(self.type, value_type)
678
689
 
679
690
  def emit(self, prefix: str = "var"):
@@ -795,7 +806,7 @@ def func_match_args(func, arg_types, kwarg_types):
795
806
  return True
796
807
 
797
808
 
798
- def get_arg_type(arg: Union[Var, Any]):
809
+ def get_arg_type(arg: Union[Var, Any]) -> type:
799
810
  if isinstance(arg, str):
800
811
  return str
801
812
 
@@ -811,7 +822,7 @@ def get_arg_type(arg: Union[Var, Any]):
811
822
  return type(arg)
812
823
 
813
824
 
814
- def get_arg_value(arg: Union[Var, Any]):
825
+ def get_arg_value(arg: Any) -> Any:
815
826
  if isinstance(arg, Sequence):
816
827
  return tuple(get_arg_value(x) for x in arg)
817
828
 
@@ -859,6 +870,9 @@ class Adjoint:
859
870
  "please save it on a file and use `importlib` if needed."
860
871
  ) from e
861
872
 
873
+ # Indicates where the function definition starts (excludes decorators)
874
+ adj.fun_def_lineno = None
875
+
862
876
  # get function source code
863
877
  adj.source = inspect.getsource(func)
864
878
  # ensures that indented class methods can be parsed as kernels
@@ -933,9 +947,6 @@ class Adjoint:
933
947
  # for unit testing errors being spit out from kernels.
934
948
  adj.skip_build = False
935
949
 
936
- # Collect the LTOIR required at link-time
937
- adj.ltoirs = []
938
-
939
950
  # allocate extra space for a function call that requires its
940
951
  # own shared memory space, we treat shared memory as a stack
941
952
  # where each function pushes and pops space off, the extra
@@ -1125,7 +1136,7 @@ class Adjoint:
1125
1136
  name = str(index)
1126
1137
 
1127
1138
  # allocate new variable
1128
- v = Var(name, type=type, constant=constant)
1139
+ v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)
1129
1140
 
1130
1141
  adj.variables.append(v)
1131
1142
 
@@ -1150,11 +1161,44 @@ class Adjoint:
1150
1161
 
1151
1162
  return var
1152
1163
 
1153
- # append a statement to the forward pass
1154
- def add_forward(adj, statement, replay=None, skip_replay=False):
1164
+ def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
1165
+ """Get a line directive for the given statement.
1166
+
1167
+ Args:
1168
+ statement: The statement to get the line directive for.
1169
+ relative_lineno: The line number of the statement relative to the function.
1170
+
1171
+ Returns:
1172
+ A line directive for the given statement, or None if no line directive is needed.
1173
+ """
1174
+
1175
+ # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
1176
+ # emit line directives in generated code if it's not being compiled with line information
1177
+ lineinfo_enabled = (
1178
+ adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
1179
+ )
1180
+
1181
+ if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
1182
+ is_comment = statement.strip().startswith("//")
1183
+ if not is_comment:
1184
+ line = relative_lineno + adj.fun_lineno
1185
+ # Convert backslashes to forward slashes for CUDA compatibility
1186
+ normalized_path = adj.filename.replace("\\", "/")
1187
+ return f'#line {line} "{normalized_path}"'
1188
+ return None
1189
+
1190
+ def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
1191
+ """Append a statement to the forward pass."""
1192
+
1193
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1194
+ adj.blocks[-1].body_forward.append(line_directive)
1195
+
1155
1196
  adj.blocks[-1].body_forward.append(adj.indentation + statement)
1156
1197
 
1157
1198
  if not skip_replay:
1199
+ if line_directive:
1200
+ adj.blocks[-1].body_replay.append(line_directive)
1201
+
1158
1202
  if replay:
1159
1203
  # if custom replay specified then output it
1160
1204
  adj.blocks[-1].body_replay.append(adj.indentation + replay)
@@ -1163,9 +1207,14 @@ class Adjoint:
1163
1207
  adj.blocks[-1].body_replay.append(adj.indentation + statement)
1164
1208
 
1165
1209
  # append a statement to the reverse pass
1166
- def add_reverse(adj, statement):
1210
+ def add_reverse(adj, statement: str) -> None:
1211
+ """Append a statement to the reverse pass."""
1212
+
1167
1213
  adj.blocks[-1].body_reverse.append(adj.indentation + statement)
1168
1214
 
1215
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
1216
+ adj.blocks[-1].body_reverse.append(line_directive)
1217
+
1169
1218
  def add_constant(adj, n):
1170
1219
  output = adj.add_var(type=type(n), constant=n)
1171
1220
  return output
@@ -1273,7 +1322,7 @@ class Adjoint:
1273
1322
 
1274
1323
  # Bind the positional and keyword arguments to the function's signature
1275
1324
  # in order to process them as Python does it.
1276
- bound_args = func.signature.bind(*args, **kwargs)
1325
+ bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
1277
1326
 
1278
1327
  # Type args are the “compile time” argument values we get from codegen.
1279
1328
  # For example, when calling `wp.vec3f(...)` from within a kernel,
@@ -1616,6 +1665,8 @@ class Adjoint:
1616
1665
  adj.blocks[-1].body_reverse.extend(reversed(reverse))
1617
1666
 
1618
1667
  def emit_FunctionDef(adj, node):
1668
+ adj.fun_def_lineno = node.lineno
1669
+
1619
1670
  for f in node.body:
1620
1671
  # Skip variable creation for standalone constants, including docstrings
1621
1672
  if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
@@ -1680,7 +1731,7 @@ class Adjoint:
1680
1731
 
1681
1732
  if var1 != var2:
1682
1733
  # insert a phi function that selects var1, var2 based on cond
1683
- out = adj.add_builtin_call("select", [cond, var1, var2])
1734
+ out = adj.add_builtin_call("where", [cond, var2, var1])
1684
1735
  adj.symbols[sym] = out
1685
1736
 
1686
1737
  symbols_prev = adj.symbols.copy()
@@ -1704,7 +1755,7 @@ class Adjoint:
1704
1755
  if var1 != var2:
1705
1756
  # insert a phi function that selects var1, var2 based on cond
1706
1757
  # note the reversed order of vars since we want to use !cond as our select
1707
- out = adj.add_builtin_call("select", [cond, var2, var1])
1758
+ out = adj.add_builtin_call("where", [cond, var1, var2])
1708
1759
  adj.symbols[sym] = out
1709
1760
 
1710
1761
  def emit_Compare(adj, node):
@@ -1848,25 +1899,6 @@ class Adjoint:
1848
1899
  ) from e
1849
1900
  raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
1850
1901
 
1851
- def emit_String(adj, node):
1852
- # string constant
1853
- return adj.add_constant(node.s)
1854
-
1855
- def emit_Num(adj, node):
1856
- # lookup constant, if it has already been assigned then return existing var
1857
- key = (node.n, type(node.n))
1858
-
1859
- if key in adj.symbols:
1860
- return adj.symbols[key]
1861
- else:
1862
- out = adj.add_constant(node.n)
1863
- adj.symbols[key] = out
1864
- return out
1865
-
1866
- def emit_Ellipsis(adj, node):
1867
- # stubbed @wp.native_func
1868
- return
1869
-
1870
1902
  def emit_Assert(adj, node):
1871
1903
  # eval condition
1872
1904
  cond = adj.eval(node.test)
@@ -1878,24 +1910,11 @@ class Adjoint:
1878
1910
 
1879
1911
  adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
1880
1912
 
1881
- def emit_NameConstant(adj, node):
1882
- if node.value:
1883
- return adj.add_constant(node.value)
1884
- elif node.value is None:
1885
- raise WarpCodegenTypeError("None type unsupported")
1886
- else:
1887
- return adj.add_constant(False)
1888
-
1889
1913
  def emit_Constant(adj, node):
1890
- if isinstance(node, ast.Str):
1891
- return adj.emit_String(node)
1892
- elif isinstance(node, ast.Num):
1893
- return adj.emit_Num(node)
1894
- elif isinstance(node, ast.Ellipsis):
1895
- return adj.emit_Ellipsis(node)
1914
+ if node.value is None:
1915
+ raise WarpCodegenTypeError("None type unsupported")
1896
1916
  else:
1897
- assert isinstance(node, ast.NameConstant) or isinstance(node, ast.Constant)
1898
- return adj.emit_NameConstant(node)
1917
+ return adj.add_constant(node.value)
1899
1918
 
1900
1919
  def emit_BinOp(adj, node):
1901
1920
  # evaluate binary operator arguments
@@ -1989,10 +2008,11 @@ class Adjoint:
1989
2008
  adj.end_while()
1990
2009
 
1991
2010
  def eval_num(adj, a):
1992
- if isinstance(a, ast.Num):
1993
- return True, a.n
1994
- if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
1995
- return True, -a.operand.n
2011
+ if isinstance(a, ast.Constant):
2012
+ return True, a.value
2013
+ if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
2014
+ # Negative constant
2015
+ return True, -a.operand.value
1996
2016
 
1997
2017
  # try and resolve the expression to an object
1998
2018
  # e.g.: wp.constant in the globals scope
@@ -2522,8 +2542,8 @@ class Adjoint:
2522
2542
  f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2523
2543
  )
2524
2544
  else:
2525
- if adj.builder_options.get("enable_backward", True):
2526
- out = adj.add_builtin_call("assign", [target, *indices, rhs])
2545
+ if warp.config.enable_vector_component_overwrites:
2546
+ out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])
2527
2547
 
2528
2548
  # re-point target symbol to out var
2529
2549
  for id in adj.symbols:
@@ -2531,8 +2551,7 @@ class Adjoint:
2531
2551
  adj.symbols[id] = out
2532
2552
  break
2533
2553
  else:
2534
- attr = adj.add_builtin_call("index", [target, *indices])
2535
- adj.add_builtin_call("store", [attr, rhs])
2554
+ adj.add_builtin_call("assign_inplace", [target, *indices, rhs])
2536
2555
 
2537
2556
  else:
2538
2557
  raise WarpCodegenError(
@@ -2575,8 +2594,8 @@ class Adjoint:
2575
2594
  attr = adj.add_builtin_call("indexref", [aggregate, index])
2576
2595
  adj.add_builtin_call("store", [attr, rhs])
2577
2596
  else:
2578
- if adj.builder_options.get("enable_backward", True):
2579
- out = adj.add_builtin_call("assign", [aggregate, index, rhs])
2597
+ if warp.config.enable_vector_component_overwrites:
2598
+ out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])
2580
2599
 
2581
2600
  # re-point target symbol to out var
2582
2601
  for id in adj.symbols:
@@ -2584,8 +2603,7 @@ class Adjoint:
2584
2603
  adj.symbols[id] = out
2585
2604
  break
2586
2605
  else:
2587
- attr = adj.add_builtin_call("index", [aggregate, index])
2588
- adj.add_builtin_call("store", [attr, rhs])
2606
+ adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
2589
2607
 
2590
2608
  else:
2591
2609
  attr = adj.emit_Attribute(lhs)
@@ -2691,10 +2709,12 @@ class Adjoint:
2691
2709
 
2692
2710
  elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2693
2711
  if isinstance(node.op, ast.Add):
2694
- adj.add_builtin_call("augassign_add", [target, *indices, rhs])
2712
+ adj.add_builtin_call("add_inplace", [target, *indices, rhs])
2695
2713
  elif isinstance(node.op, ast.Sub):
2696
- adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
2714
+ adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
2697
2715
  else:
2716
+ if warp.config.verbose:
2717
+ print(f"Warning: in-place op {node.op} is not differentiable")
2698
2718
  make_new_assign_statement()
2699
2719
  return
2700
2720
 
@@ -2724,9 +2744,6 @@ class Adjoint:
2724
2744
  ast.BoolOp: emit_BoolOp,
2725
2745
  ast.Name: emit_Name,
2726
2746
  ast.Attribute: emit_Attribute,
2727
- ast.Str: emit_String, # Deprecated in 3.8; use Constant
2728
- ast.Num: emit_Num, # Deprecated in 3.8; use Constant
2729
- ast.NameConstant: emit_NameConstant, # Deprecated in 3.8; use Constant
2730
2747
  ast.Constant: emit_Constant,
2731
2748
  ast.BinOp: emit_BinOp,
2732
2749
  ast.UnaryOp: emit_UnaryOp,
@@ -2736,14 +2753,13 @@ class Adjoint:
2736
2753
  ast.Continue: emit_Continue,
2737
2754
  ast.Expr: emit_Expr,
2738
2755
  ast.Call: emit_Call,
2739
- ast.Index: emit_Index, # Deprecated in 3.8; Use the index value directly instead.
2756
+ ast.Index: emit_Index, # Deprecated in 3.9
2740
2757
  ast.Subscript: emit_Subscript,
2741
2758
  ast.Assign: emit_Assign,
2742
2759
  ast.Return: emit_Return,
2743
2760
  ast.AugAssign: emit_AugAssign,
2744
2761
  ast.Tuple: emit_Tuple,
2745
2762
  ast.Pass: emit_Pass,
2746
- ast.Ellipsis: emit_Ellipsis,
2747
2763
  ast.Assert: emit_Assert,
2748
2764
  }
2749
2765
 
@@ -2939,12 +2955,16 @@ class Adjoint:
2939
2955
 
2940
2956
  # We want to replace the expression code in-place,
2941
2957
  # so reparse it to get the correct column info.
2942
- len_value_locs = []
2958
+ len_value_locs: List[Tuple[int, int, int]] = []
2943
2959
  expr_tree = ast.parse(static_code)
2944
2960
  assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
2945
2961
  expr_root = expr_tree.body[0].value
2946
2962
  for expr_node in ast.walk(expr_root):
2947
- if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
2963
+ if (
2964
+ isinstance(expr_node, ast.Call)
2965
+ and getattr(expr_node.func, "id", None) == "len"
2966
+ and len(expr_node.args) == 1
2967
+ ):
2948
2968
  len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
2949
2969
  try:
2950
2970
  len_value = eval(len_expr, len_expr_ctx)
@@ -3102,9 +3122,9 @@ class Adjoint:
3102
3122
 
3103
3123
  local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
3104
3124
 
3105
- constants = {}
3106
- types = {}
3107
- functions = {}
3125
+ constants: Dict[str, Any] = {}
3126
+ types: Dict[Union[Struct, type], Any] = {}
3127
+ functions: Dict[warp.context.Function, Any] = {}
3108
3128
 
3109
3129
  for node in ast.walk(adj.tree):
3110
3130
  if isinstance(node, ast.Name) and node.id not in local_variables:
@@ -3147,7 +3167,7 @@ class Adjoint:
3147
3167
  # code generation
3148
3168
 
3149
3169
  cpu_module_header = """
3150
- #define WP_TILE_BLOCK_DIM {tile_size}
3170
+ #define WP_TILE_BLOCK_DIM {block_dim}
3151
3171
  #define WP_NO_CRT
3152
3172
  #include "builtin.h"
3153
3173
 
@@ -3166,7 +3186,7 @@ cpu_module_header = """
3166
3186
  """
3167
3187
 
3168
3188
  cuda_module_header = """
3169
- #define WP_TILE_BLOCK_DIM {tile_size}
3189
+ #define WP_TILE_BLOCK_DIM {block_dim}
3170
3190
  #define WP_NO_CRT
3171
3191
  #include "builtin.h"
3172
3192
 
@@ -3189,6 +3209,7 @@ struct {name}
3189
3209
  {{
3190
3210
  {struct_body}
3191
3211
 
3212
+ {defaulted_constructor_def}
3192
3213
  CUDA_CALLABLE {name}({forward_args})
3193
3214
  {forward_initializers}
3194
3215
  {{
@@ -3231,53 +3252,53 @@ static void adj_{name}(
3231
3252
 
3232
3253
  cuda_forward_function_template = """
3233
3254
  // {filename}:{lineno}
3234
- static CUDA_CALLABLE {return_type} {name}(
3255
+ {line_directive}static CUDA_CALLABLE {return_type} {name}(
3235
3256
  {forward_args})
3236
3257
  {{
3237
- {forward_body}}}
3258
+ {forward_body}{line_directive}}}
3238
3259
 
3239
3260
  """
3240
3261
 
3241
3262
  cuda_reverse_function_template = """
3242
3263
  // {filename}:{lineno}
3243
- static CUDA_CALLABLE void adj_{name}(
3264
+ {line_directive}static CUDA_CALLABLE void adj_{name}(
3244
3265
  {reverse_args})
3245
3266
  {{
3246
- {reverse_body}}}
3267
+ {reverse_body}{line_directive}}}
3247
3268
 
3248
3269
  """
3249
3270
 
3250
3271
  cuda_kernel_template_forward = """
3251
3272
 
3252
- extern "C" __global__ void {name}_cuda_kernel_forward(
3273
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
3253
3274
  {forward_args})
3254
3275
  {{
3255
- for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3256
- _idx < dim.size;
3257
- _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3276
+ {line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3277
+ {line_directive} _idx < dim.size;
3278
+ {line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3258
3279
  {{
3259
3280
  // reset shared memory allocator
3260
- wp::tile_alloc_shared(0, true);
3281
+ {line_directive} wp::tile_alloc_shared(0, true);
3261
3282
 
3262
- {forward_body} }}
3263
- }}
3283
+ {forward_body}{line_directive} }}
3284
+ {line_directive}}}
3264
3285
 
3265
3286
  """
3266
3287
 
3267
3288
  cuda_kernel_template_backward = """
3268
3289
 
3269
- extern "C" __global__ void {name}_cuda_kernel_backward(
3290
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
3270
3291
  {reverse_args})
3271
3292
  {{
3272
- for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3273
- _idx < dim.size;
3274
- _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3293
+ {line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
3294
+ {line_directive} _idx < dim.size;
3295
+ {line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
3275
3296
  {{
3276
3297
  // reset shared memory allocator
3277
- wp::tile_alloc_shared(0, true);
3298
+ {line_directive} wp::tile_alloc_shared(0, true);
3278
3299
 
3279
- {reverse_body} }}
3280
- }}
3300
+ {reverse_body}{line_directive} }}
3301
+ {line_directive}}}
3281
3302
 
3282
3303
  """
3283
3304
 
@@ -3307,10 +3328,17 @@ extern "C" {{
3307
3328
  WP_API void {name}_cpu_forward(
3308
3329
  {forward_args})
3309
3330
  {{
3310
- for (size_t task_index = 0; task_index < dim.size; ++task_index)
3331
+ for (size_t task_index = 0; task_index < dim.size; ++task_index)
3311
3332
  {{
3333
+ // init shared memory allocator
3334
+ wp::tile_alloc_shared(0, true);
3335
+
3312
3336
  {name}_cpu_kernel_forward(
3313
3337
  {forward_params});
3338
+
3339
+ // check shared memory allocator
3340
+ wp::tile_alloc_shared(0, false, true);
3341
+
3314
3342
  }}
3315
3343
  }}
3316
3344
 
@@ -3327,8 +3355,14 @@ WP_API void {name}_cpu_backward(
3327
3355
  {{
3328
3356
  for (size_t task_index = 0; task_index < dim.size; ++task_index)
3329
3357
  {{
3358
+ // initialize shared memory allocator
3359
+ wp::tile_alloc_shared(0, true);
3360
+
3330
3361
  {name}_cpu_kernel_backward(
3331
3362
  {reverse_params});
3363
+
3364
+ // check shared memory allocator
3365
+ wp::tile_alloc_shared(0, false, true);
3332
3366
  }}
3333
3367
  }}
3334
3368
 
@@ -3410,7 +3444,7 @@ def indent(args, stops=1):
3410
3444
 
3411
3445
 
3412
3446
  # generates a C function name based on the python function name
3413
- def make_full_qualified_name(func):
3447
+ def make_full_qualified_name(func: Union[str, Callable]) -> str:
3414
3448
  if not isinstance(func, str):
3415
3449
  func = func.__qualname__
3416
3450
  return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
@@ -3440,7 +3474,8 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3440
3474
  # forward args
3441
3475
  for label, var in struct.vars.items():
3442
3476
  var_ctype = var.ctype()
3443
- forward_args.append(f"{var_ctype} const& {label} = {{}}")
3477
+ default_arg_def = " = {}" if forward_args else ""
3478
+ forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
3444
3479
  reverse_args.append(f"{var_ctype} const&")
3445
3480
 
3446
3481
  namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
@@ -3464,6 +3499,9 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3464
3499
 
3465
3500
  reverse_args.append(name + " & adj_ret")
3466
3501
 
3502
+ # explicitly defaulted default constructor if no default constructor has been defined
3503
+ defaulted_constructor_def = f"{name}() = default;" if forward_args else ""
3504
+
3467
3505
  return struct_template.format(
3468
3506
  name=name,
3469
3507
  struct_body="".join([indent_block + l for l in body]),
@@ -3473,6 +3511,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
3473
3511
  reverse_body="".join(reverse_body),
3474
3512
  prefix_add_body="".join(prefix_add_body),
3475
3513
  atomic_add_body="".join(atomic_add_body),
3514
+ defaulted_constructor_def=defaulted_constructor_def,
3476
3515
  )
3477
3516
 
3478
3517
 
@@ -3502,6 +3541,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
3502
3541
  else:
3503
3542
  lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
3504
3543
 
3544
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3545
+ lines.insert(-1, f"{line_directive}\n")
3546
+
3505
3547
  # forward pass
3506
3548
  lines += ["//---------\n"]
3507
3549
  lines += ["// forward\n"]
@@ -3509,7 +3551,7 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
3509
3551
  for f in adj.blocks[0].body_forward:
3510
3552
  lines += [f + "\n"]
3511
3553
 
3512
- return "".join([indent_block + l for l in lines])
3554
+ return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
3513
3555
 
3514
3556
 
3515
3557
  def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
@@ -3539,6 +3581,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3539
3581
  else:
3540
3582
  lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
3541
3583
 
3584
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3585
+ lines.insert(-1, f"{line_directive}\n")
3586
+
3542
3587
  # dual vars
3543
3588
  lines += ["//---------\n"]
3544
3589
  lines += ["// dual vars\n"]
@@ -3559,6 +3604,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3559
3604
  else:
3560
3605
  lines += [f"{ctype} {name} = {{}};\n"]
3561
3606
 
3607
+ if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
3608
+ lines.insert(-1, f"{line_directive}\n")
3609
+
3562
3610
  # forward pass
3563
3611
  lines += ["//---------\n"]
3564
3612
  lines += ["// forward\n"]
@@ -3579,7 +3627,7 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
3579
3627
  else:
3580
3628
  lines += ["return;\n"]
3581
3629
 
3582
- return "".join([indent_block + l for l in lines])
3630
+ return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
3583
3631
 
3584
3632
 
3585
3633
  def codegen_func(adj, c_func_name: str, device="cpu", options=None):
@@ -3587,11 +3635,11 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3587
3635
  options = {}
3588
3636
 
3589
3637
  if adj.return_var is not None and "return" in adj.arg_types:
3590
- if get_type_origin(adj.arg_types["return"]) is tuple:
3591
- if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
3638
+ if get_origin(adj.arg_types["return"]) is tuple:
3639
+ if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
3592
3640
  raise WarpCodegenError(
3593
3641
  f"The function `{adj.fun_name}` has its return type "
3594
- f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
3642
+ f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
3595
3643
  f"but the code returns {len(adj.return_var)} values."
3596
3644
  )
3597
3645
  elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
@@ -3600,7 +3648,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3600
3648
  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
3601
3649
  f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
3602
3650
  )
3603
- elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
3651
+ elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
3604
3652
  raise WarpCodegenError(
3605
3653
  f"The function `{adj.fun_name}` has its return type "
3606
3654
  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
@@ -3613,6 +3661,13 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3613
3661
  f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
3614
3662
  )
3615
3663
 
3664
+ # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
3665
+ # This is used as a catch-all C-to-Python source line mapping for any code that does not have
3666
+ # a direct mapping to a Python source line.
3667
+ func_line_directive = ""
3668
+ if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
3669
+ func_line_directive = f"{line_directive}\n"
3670
+
3616
3671
  # forward header
3617
3672
  if adj.return_var is not None and len(adj.return_var) == 1:
3618
3673
  return_type = adj.return_var[0].ctype()
@@ -3676,6 +3731,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3676
3731
  forward_body=forward_body,
3677
3732
  filename=adj.filename,
3678
3733
  lineno=adj.fun_lineno,
3734
+ line_directive=func_line_directive,
3679
3735
  )
3680
3736
 
3681
3737
  if not adj.skip_reverse_codegen:
@@ -3694,6 +3750,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3694
3750
  reverse_body=reverse_body,
3695
3751
  filename=adj.filename,
3696
3752
  lineno=adj.fun_lineno,
3753
+ line_directive=func_line_directive,
3697
3754
  )
3698
3755
 
3699
3756
  return s
@@ -3736,6 +3793,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3736
3793
  forward_body=snippet,
3737
3794
  filename=adj.filename,
3738
3795
  lineno=adj.fun_lineno,
3796
+ line_directive="",
3739
3797
  )
3740
3798
 
3741
3799
  if replay_snippet is not None:
@@ -3746,6 +3804,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3746
3804
  forward_body=replay_snippet,
3747
3805
  filename=adj.filename,
3748
3806
  lineno=adj.fun_lineno,
3807
+ line_directive="",
3749
3808
  )
3750
3809
 
3751
3810
  if adj_snippet:
@@ -3761,6 +3820,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
3761
3820
  reverse_body=reverse_body,
3762
3821
  filename=adj.filename,
3763
3822
  lineno=adj.fun_lineno,
3823
+ line_directive="",
3764
3824
  )
3765
3825
 
3766
3826
  return s
@@ -3773,6 +3833,13 @@ def codegen_kernel(kernel, device, options):
3773
3833
 
3774
3834
  adj = kernel.adj
3775
3835
 
3836
+ # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
3837
+ # This is used as a catch-all C-to-Python source line mapping for any code that does not have
3838
+ # a direct mapping to a Python source line.
3839
+ func_line_directive = ""
3840
+ if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
3841
+ func_line_directive = f"{line_directive}\n"
3842
+
3776
3843
  if device == "cpu":
3777
3844
  template_forward = cpu_kernel_template_forward
3778
3845
  template_backward = cpu_kernel_template_backward
@@ -3800,6 +3867,7 @@ def codegen_kernel(kernel, device, options):
3800
3867
  {
3801
3868
  "forward_args": indent(forward_args),
3802
3869
  "forward_body": forward_body,
3870
+ "line_directive": func_line_directive,
3803
3871
  }
3804
3872
  )
3805
3873
  template += template_forward