warp-lang 1.7.0 (warp_lang-1.7.0-py3-none-manylinux_2_28_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.
Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/codegen.py ADDED
@@ -0,0 +1,3969 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import ast
+ import builtins
+ import ctypes
+ import functools
+ import hashlib
+ import inspect
+ import math
+ import re
+ import sys
+ import textwrap
+ import types
+ from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin
+
+ import warp.config
+ from warp.types import *
+
+ # used as a globally accessible copy
+ # of current compile options (block_dim) etc
+ options = {}
+
+
+ class WarpCodegenError(RuntimeError):
+     def __init__(self, message):
+         super().__init__(message)
+
+
+ class WarpCodegenTypeError(TypeError):
+     def __init__(self, message):
+         super().__init__(message)
+
+
+ class WarpCodegenAttributeError(AttributeError):
+     def __init__(self, message):
+         super().__init__(message)
+
+
+ class WarpCodegenKeyError(KeyError):
+     def __init__(self, message):
+         super().__init__(message)
+
+
+ # map operator to function name
+ builtin_operators: Dict[type[ast.AST], str] = {}
+
+ # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
+ # nice overview of python operators
+
+ builtin_operators[ast.Add] = "add"
+ builtin_operators[ast.Sub] = "sub"
+ builtin_operators[ast.Mult] = "mul"
+ builtin_operators[ast.MatMult] = "mul"
+ builtin_operators[ast.Div] = "div"
+ builtin_operators[ast.FloorDiv] = "floordiv"
+ builtin_operators[ast.Pow] = "pow"
+ builtin_operators[ast.Mod] = "mod"
+ builtin_operators[ast.UAdd] = "pos"
+ builtin_operators[ast.USub] = "neg"
+ builtin_operators[ast.Not] = "unot"
+
+ builtin_operators[ast.Gt] = ">"
+ builtin_operators[ast.Lt] = "<"
+ builtin_operators[ast.GtE] = ">="
+ builtin_operators[ast.LtE] = "<="
+ builtin_operators[ast.Eq] = "=="
+ builtin_operators[ast.NotEq] = "!="
+
+ builtin_operators[ast.BitAnd] = "bit_and"
+ builtin_operators[ast.BitOr] = "bit_or"
+ builtin_operators[ast.BitXor] = "bit_xor"
+ builtin_operators[ast.Invert] = "invert"
+ builtin_operators[ast.LShift] = "lshift"
+ builtin_operators[ast.RShift] = "rshift"
+
+ comparison_chain_strings = [
+     builtin_operators[ast.Gt],
+     builtin_operators[ast.Lt],
+     builtin_operators[ast.LtE],
+     builtin_operators[ast.GtE],
+     builtin_operators[ast.Eq],
+     builtin_operators[ast.NotEq],
+ ]
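
Reviewer note: a minimal standalone sketch (not part of the wheel) of how a table like this drives code generation. The generator keys on the AST node type of each operator to pick the builtin name to emit; the `ops` dict below is a hypothetical two-entry reproduction of the lookup above.

    import ast

    # hypothetical subset of the builtin_operators table
    ops = {ast.Add: "add", ast.Mult: "mul"}

    node = ast.parse("a + b", mode="eval").body  # an ast.BinOp node
    print(ops[type(node.op)])                    # prints: add
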
+
+
+ def values_check_equal(a, b):
+     if isinstance(a, Sequence) and isinstance(b, Sequence):
+         if len(a) != len(b):
+             return False
+
+         return all(x == y for x, y in zip(a, b))
+
+     return a == b
+
+
+ def op_str_is_chainable(op: str) -> builtins.bool:
+     return op in comparison_chain_strings
+
+
+ def get_closure_cell_contents(obj):
+     """Retrieve a closure's cell contents or `None` if it's empty."""
+     try:
+         return obj.cell_contents
+     except ValueError:
+         pass
+
+     return None
+
+
+ def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
+     """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
+     # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
+     if not annotations:
+         return {}
+
+     if not any(isinstance(x, str) for x in annotations.values()):
+         # No annotation to un-stringize.
+         return annotations
+
+     if isinstance(obj, type):
+         # class
+         globals = {}
+         module_name = getattr(obj, "__module__", None)
+         if module_name:
+             module = sys.modules.get(module_name, None)
+             if module:
+                 globals = getattr(module, "__dict__", {})
+         locals = dict(vars(obj))
+         unwrap = obj
+     elif isinstance(obj, types.ModuleType):
+         # module
+         globals = obj.__dict__
+         locals = {}
+         unwrap = None
+     elif callable(obj):
+         # function
+         globals = getattr(obj, "__globals__", {})
+         # Capture the variables from the surrounding scope.
+         closure_vars = zip(
+             obj.__code__.co_freevars, tuple(get_closure_cell_contents(x) for x in (obj.__closure__ or ()))
+         )
+         locals = {k: v for k, v in closure_vars if v is not None}
+         unwrap = obj
+     else:
+         raise TypeError(f"{obj!r} is not a module, class, or callable.")
+
+     if unwrap is not None:
+         while True:
+             if hasattr(unwrap, "__wrapped__"):
+                 unwrap = unwrap.__wrapped__
+                 continue
+             if isinstance(unwrap, functools.partial):
+                 unwrap = unwrap.func
+                 continue
+             break
+         if hasattr(unwrap, "__globals__"):
+             globals = unwrap.__globals__
+
+     # "Inject" type parameters into the local namespace
+     # (unless they are shadowed by assignments *in* the local namespace),
+     # as a way of emulating annotation scopes when calling `eval()`
+     type_params = getattr(obj, "__type_params__", ())
+     if type_params:
+         locals = {param.__name__: param for param in type_params} | locals
+
+     return {k: v if not isinstance(v, str) else eval(v, globals, locals) for k, v in annotations.items()}
+
+
+ def get_annotations(obj: Any) -> Mapping[str, Any]:
+     """Same as `inspect.get_annotations()` but always returning un-stringized annotations."""
+     # This backports `inspect.get_annotations()` for Python 3.9 and older.
+     # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
+     if isinstance(obj, type):
+         annotations = obj.__dict__.get("__annotations__", {})
+     else:
+         annotations = getattr(obj, "__annotations__", {})
+
+     # Evaluating annotations can be done using the `eval_str` parameter with
+     # the official function from the `inspect` module.
+     return eval_annotations(annotations, obj)
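
Reviewer note: a stdlib-only illustration of the problem these helpers solve, assuming Python 3.10+ so that `inspect.get_annotations()` accepts `eval_str`. Under PEP 563, annotations are stored as strings and must be evaluated back into types, which is what `eval_annotations()` reproduces for older interpreters.

    from __future__ import annotations

    import inspect

    def f(x: int) -> float:
        return float(x)

    print(f.__annotations__)
    # {'x': 'int', 'return': 'float'}, strings rather than types

    print(inspect.get_annotations(f, eval_str=True))
    # {'x': <class 'int'>, 'return': <class 'float'>}
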
+
+
+ def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
+     """Same as `inspect.getfullargspec()` but always returning un-stringized annotations."""
+     # See https://docs.python.org/3/howto/annotations.html#manually-un-stringizing-stringized-annotations
+     spec = inspect.getfullargspec(func)
+     return spec._replace(annotations=eval_annotations(spec.annotations, func))
+
+
+ def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
+     indent = "\t"
+
+     # handle empty structs
+     if len(inst._cls.vars) == 0:
+         return f"{inst._cls.key}()"
+
+     lines = []
+     lines.append(f"{inst._cls.key}(")
+
+     for field_name, _ in inst._cls.ctype._fields_:
+         field_value = getattr(inst, field_name, None)
+
+         if isinstance(field_value, StructInstance):
+             field_value = struct_instance_repr_recursive(field_value, depth + 1)
+
+         lines.append(f"{indent * (depth + 1)}{field_name}={field_value},")
+
+     lines.append(f"{indent * depth})")
+     return "\n".join(lines)
+
+
+ class StructInstance:
+     def __init__(self, cls: Struct, ctype):
+         super().__setattr__("_cls", cls)
+
+         # maintain a ctypes object for the top-level instance of the struct
+         if not ctype:
+             super().__setattr__("_ctype", cls.ctype())
+         else:
+             super().__setattr__("_ctype", ctype)
+
+         # create Python attributes for each of the struct's variables
+         for field, var in cls.vars.items():
+             if isinstance(var.type, warp.codegen.Struct):
+                 self.__dict__[field] = StructInstance(var.type, getattr(self._ctype, field))
+             elif isinstance(var.type, warp.types.array):
+                 self.__dict__[field] = None
+             else:
+                 self.__dict__[field] = var.type()
+
+     def __getattribute__(self, name):
+         cls = super().__getattribute__("_cls")
+         if name == "native_name":
+             return cls.native_name
+
+         var = cls.vars.get(name)
+         if var is not None:
+             if isinstance(var.type, type) and issubclass(var.type, ctypes.Array):
+                 # Each field stored in a `StructInstance` is exposed as
+                 # a standard Python attribute but also has a `ctypes`
+                 # equivalent that is being updated in `__setattr__`.
+                 # However, when assigning in place an object such as a vec/mat
+                 # (e.g.: `my_struct.my_vec[0] = 1.23`), the `__setattr__` method
+                 # from `StructInstance` isn't called, and the synchronization
+                 # mechanism has no chance of updating the underlying ctype data.
+                 # As a workaround, we catch here all attempts at accessing such
+                 # objects and directly return their underlying ctype since
+                 # the Python-facing Warp vectors and matrices are implemented
+                 # using `ctypes.Array` anyways.
+                 return getattr(self._ctype, name)
+
+         return super().__getattribute__(name)
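
Reviewer note: the workaround in `__getattribute__` exists because item assignment only invokes `__setitem__` on whatever object attribute access returns; `__setattr__` never fires. A toy sketch of the underlying Python behavior (the `Toy` class is hypothetical and unrelated to the wheel):

    class Toy:
        def __init__(self):
            object.__setattr__(self, "vec", [0.0, 0.0, 0.0])

        def __setattr__(self, name, value):
            print("__setattr__ fired:", name)
            object.__setattr__(self, name, value)

    t = Toy()
    t.vec[0] = 1.23          # prints nothing: only __getattribute__ + list.__setitem__
    t.vec = [1.0, 2.0, 3.0]  # prints: __setattr__ fired: vec
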
+
+     def __setattr__(self, name, value):
+         if name not in self._cls.vars:
+             raise RuntimeError(f"Trying to set Warp struct attribute that does not exist {name}")
+
+         var = self._cls.vars[name]
+
+         # update our ctype flat copy
+         if isinstance(var.type, array):
+             if value is None:
+                 # create array with null pointer
+                 setattr(self._ctype, name, array_t())
+             else:
+                 # wp.array
+                 assert isinstance(value, array)
+                 assert types_equal(value.dtype, var.type.dtype), (
+                     f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                 )
+                 setattr(self._ctype, name, value.__ctype__())
+
+         elif isinstance(var.type, Struct):
+             # assign structs by-value, otherwise we would have problematic cases transferring ownership
+             # of the underlying ctypes data between shared Python struct instances
+
+             if not isinstance(value, StructInstance):
+                 raise RuntimeError(
+                     f"Trying to assign a non-structure value to a struct attribute with type: {self._cls.key}"
+                 )
+
+             # destination attribute on self
+             dest = getattr(self, name)
+
+             if dest._cls.key != value._cls.key:
+                 raise RuntimeError(
+                     f"Trying to assign a structure of type {value._cls.key} to an attribute of {self._cls.key}"
+                 )
+
+             # update all nested ctype vars by deep copy
+             for n in dest._cls.vars:
+                 setattr(dest, n, getattr(value, n))
+
+             # early return to avoid updating our Python StructInstance
+             return
+
+         elif issubclass(var.type, ctypes.Array):
+             # vector/matrix type, e.g. vec3
+             if value is None:
+                 setattr(self._ctype, name, var.type())
+             elif types_equal(type(value), var.type):
+                 setattr(self._ctype, name, value)
+             else:
+                 # conversion from list/tuple, ndarray, etc.
+                 setattr(self._ctype, name, var.type(value))
+
+         else:
+             # primitive type
+             if value is None:
+                 # zero initialize
+                 setattr(self._ctype, name, var.type._type_())
+             else:
+                 if hasattr(value, "_type_"):
+                     # assigning warp type value (e.g.: wp.float32)
+                     value = value.value
+                 # float16 needs conversion to uint16 bits
+                 if var.type == warp.float16:
+                     setattr(self._ctype, name, float_to_half_bits(value))
+                 else:
+                     setattr(self._ctype, name, value)
+
+         # update Python instance
+         super().__setattr__(name, value)
+
+     def __ctype__(self):
+         return self._ctype
+
+     def __repr__(self):
+         return struct_instance_repr_recursive(self, 0)
+
+     def to(self, device):
+         """Copies this struct with all array members moved onto the given device.
+
+         Arrays already living on the desired device are referenced as-is, while
+         arrays being moved are copied.
+         """
+         out = self._cls()
+         stack = [(self, out, k, v) for k, v in self._cls.vars.items()]
+         while stack:
+             src, dst, name, var = stack.pop()
+             value = getattr(src, name)
+             if isinstance(var.type, array):
+                 # array_t
+                 setattr(dst, name, value.to(device))
+             elif isinstance(var.type, Struct):
+                 # nested struct
+                 new_struct = value._cls()
+                 setattr(dst, name, new_struct)
+                 # The call to `setattr()` just above makes a copy of `new_struct`
+                 # so we need to reference that new instance of the struct.
+                 new_struct = getattr(dst, name)
+                 stack.extend((value, new_struct, k, v) for k, v in value._cls.vars.items())
+             else:
+                 setattr(dst, name, value)
+
+         return out
+
+     # type description used in numpy structured arrays
+     def numpy_dtype(self):
+         return self._cls.numpy_dtype()
+
+     # value usable in numpy structured arrays of .numpy_dtype(), e.g. (42, 13.37, [1.0, 2.0, 3.0])
+     def numpy_value(self):
+         npvalue = []
+         for name, var in self._cls.vars.items():
+             # get the attribute value
+             value = getattr(self._ctype, name)
+
+             if isinstance(var.type, array):
+                 # array_t
+                 npvalue.append(value.numpy_value())
+             elif isinstance(var.type, Struct):
+                 # nested struct
+                 npvalue.append(value.numpy_value())
+             elif issubclass(var.type, ctypes.Array):
+                 if len(var.type._shape_) == 1:
+                     # vector
+                     npvalue.append(list(value))
+                 else:
+                     # matrix
+                     npvalue.append([list(row) for row in value])
+             else:
+                 # scalar
+                 if var.type == warp.float16:
+                     npvalue.append(half_bits_to_float(value))
+                 else:
+                     npvalue.append(value)
+
+         return tuple(npvalue)
+
+
+ class Struct:
+     hash: bytes
+
+     def __init__(self, cls: type, key: str, module: warp.context.Module):
+         self.cls = cls
+         self.module = module
+         self.key = key
+         self.vars: Dict[str, Var] = {}
+
+         annotations = get_annotations(self.cls)
+         for label, type in annotations.items():
+             self.vars[label] = Var(label, type)
+
+         fields = []
+         for label, var in self.vars.items():
+             if isinstance(var.type, array):
+                 fields.append((label, array_t))
+             elif isinstance(var.type, Struct):
+                 fields.append((label, var.type.ctype))
+             elif issubclass(var.type, ctypes.Array):
+                 fields.append((label, var.type))
+             else:
+                 # HACK: fp16 requires conversion functions from warp.so
+                 if var.type is warp.float16:
+                     warp.init()
+                 fields.append((label, var.type._type_))
+
+         class StructType(ctypes.Structure):
+             # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
+             _fields_ = fields or [("_dummy_", ctypes.c_byte)]
+
+         self.ctype = StructType
+
+         # Compute the hash. We can cache the hash because it's static, even with nested structs.
+         # All field types are specified in the annotations, so they're resolved at declaration time.
+         ch = hashlib.sha256()
+
+         ch.update(bytes(self.key, "utf-8"))
+
+         for name, type_hint in annotations.items():
+             s = f"{name}:{warp.types.get_type_code(type_hint)}"
+             ch.update(bytes(s, "utf-8"))
+
+             # recurse on nested structs
+             if isinstance(type_hint, Struct):
+                 ch.update(type_hint.hash)
+
+         self.hash = ch.digest()
+
+         # generate unique identifier for structs in native code
+         hash_suffix = f"{self.hash.hex()[:8]}"
+         self.native_name = f"{self.key}_{hash_suffix}"
+
+         # create default constructor (zero-initialize)
+         self.default_constructor = warp.context.Function(
+             func=None,
+             key=self.native_name,
+             namespace="",
+             value_func=lambda *_: self,
+             input_types={},
+             initializer_list_func=lambda *_: False,
+             native_func=self.native_name,
+         )
+
+         # build a constructor that takes each param as a value
+         input_types = {label: var.type for label, var in self.vars.items()}
+
+         self.value_constructor = warp.context.Function(
+             func=None,
+             key=self.native_name,
+             namespace="",
+             value_func=lambda *_: self,
+             input_types=input_types,
+             initializer_list_func=lambda *_: False,
+             native_func=self.native_name,
+         )
+
+         self.default_constructor.add_overload(self.value_constructor)
+
+         if module:
+             module.register_struct(self)
+
+     def __call__(self):
+         """
+         This function returns s = StructInstance(self)
+         s uses self.cls as template.
+         To enable autocomplete on s, we inherit from self.cls.
+         For example,
+
+         @wp.struct
+         class A:
+             # annotations
+             ...
+
+         The type annotations are inherited in A(), allowing autocomplete in kernels
+         """
+         # return StructInstance(self)
+
+         class NewStructInstance(self.cls, StructInstance):
+             def __init__(inst):
+                 StructInstance.__init__(inst, self, None)
+
+         # make sure warp.types.get_type_code works with this StructInstance
+         NewStructInstance.cls = self.cls
+         NewStructInstance.native_name = self.native_name
+
+         return NewStructInstance()
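
Reviewer note: a short usage sketch of the `@wp.struct` decorator, which ultimately calls into `Struct.__call__` above. The field names are illustrative, not from the wheel.

    import warp as wp

    @wp.struct
    class Particle:
        pos: wp.vec3
        mass: float

    p = Particle()                  # Struct.__call__ returns a NewStructInstance
    p.pos = wp.vec3(0.0, 1.0, 0.0)  # StructInstance.__setattr__ syncs the ctypes copy
    p.mass = 1.0
    print(p)                        # recursive repr via struct_instance_repr_recursive
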
+
+     def initializer(self):
+         return self.default_constructor
+
+     # return structured NumPy dtype, including field names, formats, and offsets
+     def numpy_dtype(self):
+         names = []
+         formats = []
+         offsets = []
+         for name, var in self.vars.items():
+             names.append(name)
+             offsets.append(getattr(self.ctype, name).offset)
+             if isinstance(var.type, array):
+                 # array_t
+                 formats.append(array_t.numpy_dtype())
+             elif isinstance(var.type, Struct):
+                 # nested struct
+                 formats.append(var.type.numpy_dtype())
+             elif issubclass(var.type, ctypes.Array):
+                 scalar_typestr = type_typestr(var.type._wp_scalar_type_)
+                 if len(var.type._shape_) == 1:
+                     # vector
+                     formats.append(f"{var.type._length_}{scalar_typestr}")
+                 else:
+                     # matrix
+                     formats.append(f"{var.type._shape_}{scalar_typestr}")
+             else:
+                 # scalar
+                 formats.append(type_typestr(var.type))
+
+         return {"names": names, "formats": formats, "offsets": offsets, "itemsize": ctypes.sizeof(self.ctype)}
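
Reviewer note: the returned dict plugs directly into `numpy.dtype`. A hedged sketch with made-up field names, formats, and offsets:

    import numpy as np

    desc = {"names": ["mass", "vel"], "formats": ["<f4", "3f4"], "offsets": [0, 4], "itemsize": 16}
    arr = np.zeros(2, dtype=np.dtype(desc))
    arr[0] = (1.0, [0.0, 1.0, 0.0])  # tuples shaped like numpy_value() output
    print(arr["mass"])               # [1. 0.]
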
+
+     # constructs a Warp struct instance from a pointer to the ctype
+     def from_ptr(self, ptr):
+         if not ptr:
+             raise RuntimeError("NULL pointer exception")
+
+         # create a new struct instance
+         instance = self()
+
+         for name, var in self.vars.items():
+             offset = getattr(self.ctype, name).offset
+             if isinstance(var.type, array):
+                 # We could reconstruct wp.array from array_t, but it's problematic.
+                 # There's no guarantee that the original wp.array is still allocated and
+                 # no easy way to make a backref.
+                 # Instead, we just create a stub annotation, which is not a fully usable array object.
+                 setattr(instance, name, array(dtype=var.type.dtype, ndim=var.type.ndim))
+             elif isinstance(var.type, Struct):
+                 # nested struct
+                 value = var.type.from_ptr(ptr + offset)
+                 setattr(instance, name, value)
+             elif issubclass(var.type, ctypes.Array):
+                 # vector/matrix
+                 value = var.type.from_ptr(ptr + offset)
+                 setattr(instance, name, value)
+             else:
+                 # scalar
+                 cvalue = ctypes.cast(ptr + offset, ctypes.POINTER(var.type._type_)).contents
+                 if var.type == warp.float16:
+                     setattr(instance, name, half_bits_to_float(cvalue))
+                 else:
+                     setattr(instance, name, cvalue.value)
+
+         return instance
+
+
+ class Reference:
+     def __init__(self, value_type):
+         self.value_type = value_type
+
+
+ def is_reference(type: Any) -> builtins.bool:
+     return isinstance(type, Reference)
+
+
+ def strip_reference(arg: Any) -> Any:
+     if is_reference(arg):
+         return arg.value_type
+     else:
+         return arg
+
+
+ def compute_type_str(base_name, template_params):
+     if not template_params:
+         return base_name
+
+     def param2str(p):
+         if isinstance(p, int):
+             return str(p)
+         elif hasattr(p, "_type_"):
+             if p.__name__ == "bool":
+                 return "bool"
+             else:
+                 return f"wp::{p.__name__}"
+         elif is_tile(p):
+             return p.ctype()
+
+         return p.__name__
+
+     return f"{base_name}<{','.join(map(param2str, template_params))}>"
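
Reviewer note: assuming the function above is in scope, this is roughly the kind of C++ template string it produces. The `float32` class below is a hypothetical stand-in for `wp.float32`, which is likewise a ctypes scalar carrying a `_type_` attribute.

    import ctypes

    class float32(ctypes.c_float):  # stand-in for wp.float32
        pass

    print(compute_type_str("wp::vec_t", (3, float32)))
    # wp::vec_t<3,wp::float32>
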
+
+
+ class Var:
+     def __init__(
+         self,
+         label: str,
+         type: type,
+         requires_grad: builtins.bool = False,
+         constant: Optional[builtins.bool] = None,
+         prefix: builtins.bool = True,
+         relative_lineno: Optional[int] = None,
+     ):
+         # convert built-in types to wp types
+         if type == float:
+             type = float32
+         elif type == int:
+             type = int32
+         elif type == builtins.bool:
+             type = bool
+
+         self.label = label
+         self.type = type
+         self.requires_grad = requires_grad
+         self.constant = constant
+         self.prefix = prefix
+
+         # records whether this Var has been read from in a kernel function (array only)
+         self.is_read = False
+         # records whether this Var has been written to in a kernel function (array only)
+         self.is_write = False
+
+         # used to associate a view array Var with its parent array Var
+         self.parent = None
+
+         # Used to associate the variable with the Python statement that resulted in it being created.
+         self.relative_lineno = relative_lineno
+
+     def __str__(self):
+         return self.label
+
+     @staticmethod
+     def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
+         if is_array(t):
+             if hasattr(t.dtype, "_wp_generic_type_str_"):
+                 dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
+             elif isinstance(t.dtype, Struct):
+                 dtypestr = t.dtype.native_name
+             elif t.dtype.__name__ in ("bool", "int", "float"):
+                 dtypestr = t.dtype.__name__
+             else:
+                 dtypestr = f"wp::{t.dtype.__name__}"
+             classstr = f"wp::{type(t).__name__}"
+             return f"{classstr}_t<{dtypestr}>"
+         elif is_tile(t):
+             return t.ctype()
+         elif isinstance(t, Struct):
+             return t.native_name
+         elif isinstance(t, type) and issubclass(t, StructInstance):
+             # ensure the actual Struct name is used instead of "NewStructInstance"
+             return t.native_name
+         elif is_reference(t):
+             if not value_type:
+                 return Var.type_to_ctype(t.value_type) + "*"
+             else:
+                 return Var.type_to_ctype(t.value_type)
+         elif hasattr(t, "_wp_generic_type_str_"):
+             return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
+         elif t.__name__ in ("bool", "int", "float"):
+             return t.__name__
+         else:
+             return f"wp::{t.__name__}"
+
+     def ctype(self, value_type: builtins.bool = False) -> str:
+         return Var.type_to_ctype(self.type, value_type)
+
+     def emit(self, prefix: str = "var"):
+         if self.prefix:
+             return f"{prefix}_{self.label}"
+         else:
+             return self.label
+
+     def emit_adj(self):
+         return self.emit("adj")
+
+     def mark_read(self):
+         """Marks this Var as having been read from in a kernel (array only)."""
+         if not is_array(self.type):
+             return
+
+         self.is_read = True
+
+         # recursively update all parent states
+         parent = self.parent
+         while parent is not None:
+             parent.is_read = True
+             parent = parent.parent
+
+     def mark_write(self, **kwargs):
+         """Marks this Var as having been written to in a kernel (array only)."""
+         if not is_array(self.type):
+             return
+
+         # detect if we are writing to an array after reading from it within the same kernel
+         if self.is_read and warp.config.verify_autograd_array_access:
+             if all(key in kwargs for key in ("kernel_name", "filename", "lineno")):
+                 print(
+                     f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
+                 )
+             else:
+                 print(
+                     f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
+                 )
+         self.is_write = True
+
+         # recursively update all parent states
+         parent = self.parent
+         while parent is not None:
+             parent.is_write = True
+             parent = parent.parent
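
Reviewer note: a sketch of how these flags surface to users, assuming `wp.config.verify_autograd_array_access` behaves as the comment above describes; reading then writing the same array inside one kernel should trigger the warning.

    import warp as wp

    wp.config.verify_autograd_array_access = True

    @wp.kernel
    def scale(a: wp.array(dtype=float)):
        i = wp.tid()
        x = a[i]        # mark_read() on the array argument
        a[i] = 2.0 * x  # mark_write() after a read: prints the warning

    a = wp.zeros(8, dtype=float, requires_grad=True)
    wp.launch(scale, dim=8, inputs=[a])
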
+
+
+ class Block:
+     # Represents a basic block of instructions, e.g.: list
+     # of straight line instructions inside a for-loop or conditional
+
+     def __init__(self):
+         # list of statements inside this block
+         self.body_forward = []
+         self.body_replay = []
+         self.body_reverse = []
+
+         # list of vars declared in this block
+         self.vars = []
+
+
+ def apply_defaults(
+     bound_args: inspect.BoundArguments,
+     values: Mapping[str, Any],
+ ):
+     # Similar to Python's `inspect.BoundArguments.apply_defaults()`
+     # but with the possibility to pass an augmented set of default values.
+     arguments = bound_args.arguments
+     new_arguments = []
+     for name in bound_args._signature.parameters.keys():
+         try:
+             new_arguments.append((name, arguments[name]))
+         except KeyError:
+             if name in values:
+                 new_arguments.append((name, values[name]))
+
+     bound_args.arguments = dict(new_arguments)
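
Reviewer note: for comparison, the stdlib behavior this helper generalizes (it fills defaults from an augmented mapping rather than only from the signature itself):

    import inspect

    def g(a, b=2, c=3):
        pass

    bound = inspect.signature(g).bind(1, c=30)
    bound.apply_defaults()  # stdlib version: fills b from the signature default
    print(bound.arguments)  # {'a': 1, 'b': 2, 'c': 30}
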
+
+
+ def func_match_args(func, arg_types, kwarg_types):
+     try:
+         # Try to bind the given arguments to the function's signature.
+         # This is not checking whether the argument types are matching,
+         # rather it's just assigning each argument to the corresponding
+         # function parameter.
+         bound_arg_types = func.signature.bind(*arg_types, **kwarg_types)
+     except TypeError:
+         return False
+
+     # Populate the bound arguments with any default values.
+     default_arg_types = {
+         k: None if v is None else get_arg_type(v)
+         for k, v in func.defaults.items()
+         if k not in bound_arg_types.arguments
+     }
+     apply_defaults(bound_arg_types, default_arg_types)
+     bound_arg_types = tuple(bound_arg_types.arguments.values())
+
+     # Check the given argument types against the ones defined on the function.
+     for bound_arg_type, func_arg_type in zip(bound_arg_types, func.input_types.values()):
+         # Let the `value_func` callback infer the type.
+         if bound_arg_type is None:
+             continue
+
+         # if arg type registered as Any, treat as
+         # template allowing any type to match
+         if func_arg_type == Any:
+             continue
+
+         # handle function refs as a special case
+         if func_arg_type == Callable and isinstance(bound_arg_type, warp.context.Function):
+             continue
+
+         # check arg type matches input variable type
+         if not types_equal(func_arg_type, strip_reference(bound_arg_type), match_generic=True):
+             return False
+
+     return True
+
+
+ def get_arg_type(arg: Union[Var, Any]) -> type:
+     if isinstance(arg, str):
+         return str
+
+     if isinstance(arg, Sequence):
+         return tuple(get_arg_type(x) for x in arg)
+
+     if isinstance(arg, (type, warp.context.Function)):
+         return arg
+
+     if isinstance(arg, Var):
+         return arg.type
+
+     return type(arg)
+
+
+ def get_arg_value(arg: Any) -> Any:
+     if isinstance(arg, Sequence):
+         return tuple(get_arg_value(x) for x in arg)
+
+     if isinstance(arg, (type, warp.context.Function)):
+         return arg
+
+     if isinstance(arg, Var):
+         return arg.constant
+
+     return arg
+
+
+ class Adjoint:
+     # Source code transformer, this class takes a Python function and
+     # generates forward and backward SSA forms of the function instructions
+
+     def __init__(
+         adj,
+         func: Callable[..., Any],
+         overload_annotations=None,
+         is_user_function=False,
+         skip_forward_codegen=False,
+         skip_reverse_codegen=False,
+         custom_reverse_mode=False,
+         custom_reverse_num_input_args=-1,
+         transformers: Optional[List[ast.NodeTransformer]] = None,
+     ):
+         adj.func = func
+
+         adj.is_user_function = is_user_function
+
+         # whether the generation of the forward code is skipped for this function
+         adj.skip_forward_codegen = skip_forward_codegen
+         # whether the generation of the adjoint code is skipped for this function
+         adj.skip_reverse_codegen = skip_reverse_codegen
+
+         # extract name of source file
+         adj.filename = inspect.getsourcefile(func) or "unknown source file"
+         # get source file line number where function starts
+         try:
+             _, adj.fun_lineno = inspect.getsourcelines(func)
+         except OSError as e:
+             raise RuntimeError(
+                 "Directly evaluating Warp code defined as a string using `exec()` is not supported, "
+                 "please save it to a file and use `importlib` if needed."
+             ) from e
+
+         # Indicates where the function definition starts (excludes decorators)
+         adj.fun_def_lineno = None
+
+         # get function source code
+         adj.source = inspect.getsource(func)
+         # ensures that indented class methods can be parsed as kernels
+         adj.source = textwrap.dedent(adj.source)
+
+         adj.source_lines = adj.source.splitlines()
+
+         if transformers is None:
+             transformers = []
+
+         # build AST and apply node transformers
+         adj.tree = ast.parse(adj.source)
+         adj.transformers = transformers
+         for transformer in transformers:
+             adj.tree = transformer.visit(adj.tree)
+
+         adj.fun_name = adj.tree.body[0].name
+
+         # for keeping track of line number in function code
+         adj.lineno = None
+
+         # whether the forward code shall be used for the reverse pass and a custom
+         # function signature is applied to the reverse version of the function
+         adj.custom_reverse_mode = custom_reverse_mode
+         # the number of function arguments that pertain to the forward function
+         # input arguments (i.e. the number of arguments that are not adjoint arguments)
+         adj.custom_reverse_num_input_args = custom_reverse_num_input_args
+
+         # parse argument types
+         argspec = get_full_arg_spec(func)
+
+         # ensure all arguments are annotated
+         if overload_annotations is None:
+             # use source-level argument annotations
+             if len(argspec.annotations) < len(argspec.args):
+                 raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
+             adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
+         else:
+             # use overload argument annotations
+             for arg_name in argspec.args:
+                 if arg_name not in overload_annotations:
+                     raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
+             adj.arg_types = overload_annotations.copy()
+
+         adj.args = []
+         adj.symbols = {}
+
+         for name, type in adj.arg_types.items():
+             # skip return hint
+             if name == "return":
+                 continue
+
+             # add variable for argument
+             arg = Var(name, type, False)
+             adj.args.append(arg)
+
+             # pre-populate symbol dictionary with function argument names
+             # this is to avoid registering false references to overshadowed modules
+             adj.symbols[name] = arg
+
+         # try to replace static expressions by their constant result if the
+         # expression can be evaluated at declaration time
+         adj.static_expressions: Dict[str, Any] = {}
+         if "static" in adj.source:
+             adj.replace_static_expressions()
+
+         # There are cases where the same module might be rebuilt multiple times,
+         # for example when kernels are nested inside of functions, or when
+         # a kernel's launch raises an exception. Ideally we'd always want to
+         # avoid rebuilding kernels but some corner cases seem to depend on it,
+         # so we only avoid rebuilding kernels that errored out to give a chance
+         # for unit testing errors being spit out from kernels.
+         adj.skip_build = False
+
950
+ # allocate extra space for a function call that requires its
951
+ # own shared memory space, we treat shared memory as a stack
952
+ # where each function pushes and pops space off, the extra
953
+ # quantity is the 'roofline' amount required for the entire kernel
954
+ def alloc_shared_extra(adj, num_bytes):
955
+ adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
956
+
957
+ # returns the total number of bytes for a function
958
+ # based on it's own requirements + worst case
959
+ # requirements of any dependent functions
960
+ def get_total_required_shared(adj):
961
+ total_shared = 0
962
+
963
+ for var in adj.variables:
964
+ if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
965
+ total_shared += var.type.size_in_bytes()
966
+
967
+ return total_shared + adj.max_required_extra_shared_memory
+
+ # generate function ssa form and adjoint
+ def build(adj, builder, default_builder_options=None):
+ # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
+ for arg in adj.args:
+ arg.is_read = False
+ arg.is_write = False
+
+ if adj.skip_build:
+ return
+
+ adj.builder = builder
+
+ if default_builder_options is None:
+ default_builder_options = {}
+
+ if adj.builder:
+ adj.builder_options = adj.builder.options
+ else:
+ adj.builder_options = default_builder_options
+
+ global options
+ options = adj.builder_options
+
+ adj.symbols = {} # map from symbols to adjoint variables
+ adj.variables = [] # list of local variables (in order)
+
+ adj.return_var = None # return type for function or kernel
+ adj.loop_symbols = [] # symbols at the start of each loop
+ adj.loop_const_iter_symbols = (
+ set()
+ ) # constant iteration variables for static loops (mutating them does not raise an error)
+
+ # blocks
+ adj.blocks = [Block()]
+ adj.loop_blocks = []
+
+ # holds current indent level
+ adj.indentation = ""
+
+ # used to generate new label indices
+ adj.label_count = 0
+
+ # tracks how much additional shared memory is required by any dependent function calls
+ adj.max_required_extra_shared_memory = 0
+
+ # update symbol map for each argument
+ for a in adj.args:
+ adj.symbols[a.label] = a
+
+ # recursively evaluate function body
+ try:
+ adj.eval(adj.tree.body[0])
+ except Exception:
+ try:
+ lineno = adj.lineno + adj.fun_lineno
+ line = adj.source_lines[adj.lineno]
+ msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
+ ex, data, traceback = sys.exc_info()
+ e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
+ finally:
+ adj.skip_build = True
+ adj.builder = None
+ raise e
+
+ if builder is not None:
+ for a in adj.args:
+ if isinstance(a.type, Struct):
+ builder.build_struct_recursive(a.type)
+ elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
+ builder.build_struct_recursive(a.type.dtype)
+
+ # release builder reference for GC
+ adj.builder = None
+
+ # code generation methods
+ def format_template(adj, template, input_vars, output_var):
+ # output var is always the 0th index
+ args = [output_var] + input_vars
+ s = template.format(*args)
+
+ return s
+
+ # generates a list of formatted args
+ def format_args(adj, prefix, args):
+ arg_strs = []
+
+ for a in args:
+ if isinstance(a, warp.context.Function):
+ # functions don't have a var_ prefix so strip it off here
+ if prefix == "var":
+ arg_strs.append(f"{a.namespace}{a.native_func}")
+ else:
+ arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
+ elif is_reference(a.type):
+ arg_strs.append(f"{prefix}_{a}")
+ elif isinstance(a, Var):
+ arg_strs.append(a.emit(prefix))
+ else:
+ raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
+
+ return arg_strs
+
+ # generates argument string for a forward function call
+ def format_forward_call_args(adj, args, use_initializer_list):
+ arg_str = ", ".join(adj.format_args("var", args))
+ if use_initializer_list:
+ return f"{{{arg_str}}}"
+ return arg_str
+
+ # generates argument string for a reverse function call
+ def format_reverse_call_args(
+ adj,
+ args_var,
+ args,
+ args_out,
+ use_initializer_list,
+ has_output_args=True,
+ require_original_output_arg=False,
+ ):
+ formatted_var = adj.format_args("var", args_var)
+ formatted_out = []
+ if has_output_args and (require_original_output_arg or len(args_out) > 1):
+ formatted_out = adj.format_args("var", args_out)
+ formatted_var_adj = adj.format_args(
+ "&adj" if use_initializer_list else "adj",
+ args,
+ )
+ formatted_out_adj = adj.format_args("adj", args_out)
+
+ if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
+ # there are no adjoint arguments, so we don't need to call the reverse function
+ return None
+
+ if use_initializer_list:
+ var_str = f"{{{', '.join(formatted_var)}}}"
+ out_str = f"{{{', '.join(formatted_out)}}}"
+ adj_str = f"{{{', '.join(formatted_var_adj)}}}"
+ out_adj_str = ", ".join(formatted_out_adj)
+ if len(args_out) > 1:
+ arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
+ else:
+ arg_str = ", ".join([var_str, adj_str, out_adj_str])
+ else:
+ arg_str = ", ".join(formatted_var + formatted_out + formatted_var_adj + formatted_out_adj)
+ return arg_str
+
+ def indent(adj):
+ adj.indentation = adj.indentation + "    "
+
+ def dedent(adj):
+ adj.indentation = adj.indentation[:-4]
+
+ def begin_block(adj, name="block"):
+ b = Block()
+
+ # give block a unique id
+ b.label = name + "_" + str(adj.label_count)
+ adj.label_count += 1
+
+ adj.blocks.append(b)
+ return b
+
+ def end_block(adj):
+ return adj.blocks.pop()
+
+ def add_var(adj, type=None, constant=None):
+ index = len(adj.variables)
+ name = str(index)
+
+ # allocate new variable
+ v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)
+
+ adj.variables.append(v)
+
+ adj.blocks[-1].vars.append(v)
+
+ return v
+
+ def register_var(adj, var):
+ # We sometimes initialize `Var` instances that might be thrown away
+ # afterwards, so this method allows us to defer their registration among
+ # the list of primal vars until later on, instead of registering them
+ # immediately if we were to use `adj.add_var()` or `adj.add_constant()`.
+
+ if isinstance(var, (Reference, warp.context.Function)):
+ return var
+
+ if isinstance(var, int):
+ return adj.add_constant(var)
+
+ if var.label is None:
+ return adj.add_var(var.type, var.constant)
+
+ return var
+
+ def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
+ """Get a line directive for the given statement.
+
+ Args:
+ statement: The statement to get the line directive for.
+ relative_lineno: The line number of the statement relative to the function.
+
+ Returns:
+ A line directive for the given statement, or None if no line directive is needed.
+ """
+
+ # lineinfo is enabled by default in debug mode regardless of the builder option; we don't want to
+ # unnecessarily emit line directives in generated code if it's not being compiled with line information
+ lineinfo_enabled = (
+ adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
+ )
+
+ if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
+ is_comment = statement.strip().startswith("//")
+ if not is_comment:
+ line = relative_lineno + adj.fun_lineno
+ # Convert backslashes to forward slashes for CUDA compatibility
+ normalized_path = adj.filename.replace("\\", "/")
+ return f'#line {line} "{normalized_path}"'
+ return None
+
+ def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
+ """Append a statement to the forward pass."""
+
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
+ adj.blocks[-1].body_forward.append(line_directive)
+
+ adj.blocks[-1].body_forward.append(adj.indentation + statement)
+
+ if not skip_replay:
+ if line_directive:
+ adj.blocks[-1].body_replay.append(line_directive)
+
+ if replay:
+ # if custom replay specified then output it
+ adj.blocks[-1].body_replay.append(adj.indentation + replay)
+ else:
+ # by default just replay the original statement
+ adj.blocks[-1].body_replay.append(adj.indentation + statement)
+
+ # append a statement to the reverse pass
+ def add_reverse(adj, statement: str) -> None:
+ """Append a statement to the reverse pass."""
+
+ adj.blocks[-1].body_reverse.append(adj.indentation + statement)
+
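+ # note: reverse-pass statements are later emitted in reversed order (see end_for/end_while),
+ # so appending the directive after the statement places it before the statement in the output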
+ if line_directive := adj.get_line_directive(statement, adj.lineno):
+ adj.blocks[-1].body_reverse.append(line_directive)
+
+ def add_constant(adj, n):
+ output = adj.add_var(type=type(n), constant=n)
+ return output
+
+ def load(adj, var):
+ if is_reference(var.type):
+ var = adj.add_builtin_call("load", [var])
+ return var
+
+ def add_comp(adj, op_strings, left, comps):
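+ # e.g. the chained comparison `a < b < c` lowers to: var_N = ((a < b) && (b < c));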
+ output = adj.add_var(builtins.bool)
+
+ left = adj.load(left)
+ s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
+
+ prev_comp_var = None
+
+ for op, comp in zip(op_strings, comps):
+ comp_chainable = op_str_is_chainable(op)
+ if comp_chainable and prev_comp_var:
+ # We restrict chaining to operands of the same type
+ if prev_comp_var.type is comp.type:
+ prev_comp_var = adj.load(prev_comp_var)
+ comp_var = adj.load(comp)
+ s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
+ else:
+ raise WarpCodegenTypeError(
+ f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
+ )
+ else:
+ comp_var = adj.load(comp)
+ s += op + " " + comp_var.emit() + ") "
+
+ prev_comp_var = comp_var
+
+ s = s.rstrip() + ";"
+
+ adj.add_forward(s)
+
+ return output
+
+ def add_bool_op(adj, op_string, exprs):
+ exprs = [adj.load(expr) for expr in exprs]
+ output = adj.add_var(builtins.bool)
+ command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
+ adj.add_forward(command)
+
+ return output
+
+ def resolve_func(adj, func, arg_types, kwarg_types, min_outputs):
+ if not func.is_builtin():
+ # user-defined function
+ overload = func.get_overload(arg_types, kwarg_types)
+ if overload is not None:
+ return overload
+ else:
+ # if func is overloaded then perform overload resolution here
+ # we validate argument types before they go to generated native code
+ for f in func.overloads:
+ # skip type checking for variadic functions
+ if not f.variadic:
+ # check that argument counts are compatible (there may be some default args)
+ if len(f.input_types) < len(arg_types) + len(kwarg_types):
+ continue
+
+ if not func_match_args(f, arg_types, kwarg_types):
+ continue
+
+ # check output dimensions match expectations
+ if min_outputs:
+ if not isinstance(f.value_type, Sequence) or len(f.value_type) != min_outputs:
+ continue
+
+ # found a match, use it
+ return f
+
+ # unresolved function, report error
+ arg_type_reprs = []
+
+ for x in arg_types:
+ if isinstance(x, warp.context.Function):
+ arg_type_reprs.append("function")
+ else:
+ # shorten Warp primitive type names
+ if isinstance(x, Sequence):
+ if len(x) != 1:
+ raise WarpCodegenError("Argument must not be the result from a multi-valued function")
+ arg_type = x[0]
+ else:
+ arg_type = x
+
+ arg_type_reprs.append(type_repr(arg_type))
+
+ raise WarpCodegenError(
+ f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_type_reprs)}]"
+ )
+
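+ # overview: resolve the overload, bind positional/keyword arguments, fold in template
+ # type args and defaults, allocate output vars, then emit the forward/replay/reverse calls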
+ def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
+ # Extract the types and values passed as arguments to the function call.
+ arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
+ kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}
+
+ # Resolve the exact function signature among any existing overload.
+ func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)
+
+ # Bind the positional and keyword arguments to the function's signature
+ # in order to process them as Python does.
+ bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)
+
+ # Type args are the “compile time” argument values we get from codegen.
+ # For example, when calling `wp.vec3f(...)` from within a kernel,
+ # this in fact translates to calling the `vector()` built-in augmented
+ # with the type args `length=3, dtype=float`.
+ # Eventually, these need to be passed to the underlying C++ function,
+ # so we update the arguments with the type args here.
+ if type_args:
+ for arg in type_args:
+ if arg in bound_args.arguments:
+ # In case of conflict, ideally we'd throw an error since
+ # what comes from codegen should be the source of truth
+ # and users also passing the same value as an argument
+ # is redundant (e.g.: `wp.mat22(shape=(2, 2))`).
+ # However, for backward compatibility, we allow that form
+ # as long as the values are equal.
+ if values_check_equal(get_arg_value(bound_args.arguments[arg]), type_args[arg]):
+ continue
+
+ raise RuntimeError(
+ f"Remove the extraneous `{arg}` parameter "
+ f"when calling the templated version of "
+ f"`wp.{func.native_func}()`"
+ )
+
+ type_vars = {k: Var(None, type=type(v), constant=v) for k, v in type_args.items()}
+ apply_defaults(bound_args, type_vars)
+
+ if func.defaults:
+ default_vars = {
+ k: Var(None, type=type(v), constant=v)
+ for k, v in func.defaults.items()
+ if k not in bound_args.arguments and v is not None
+ }
+ apply_defaults(bound_args, default_vars)
+
+ bound_args = bound_args.arguments
+
+ # if it is a user-function then build it recursively
+ if not func.is_builtin() and func not in adj.builder.functions:
+ adj.builder.build_function(func)
+ # add custom grad, replay functions to the list of functions
+ # to be built later (invalid code could be generated if we built them now)
+ # so that they are not missed when only the forward function is imported
+ # from another module
+ if func.custom_grad_func:
+ adj.builder.deferred_functions.append(func.custom_grad_func)
+ if func.custom_replay_func:
+ adj.builder.deferred_functions.append(func.custom_replay_func)
+
+ # Resolve the return value based on the types and values of the given arguments.
+ bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
+ bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
+ return_type = func.value_func(
+ {k: strip_reference(v) for k, v in bound_arg_types.items()},
+ bound_arg_values,
+ )
+
+ # immediately allocate output variables so we can pass them into the dispatch method
+ if return_type is None:
+ # void function
+ output = None
+ output_list = []
+ elif not isinstance(return_type, Sequence) or len(return_type) == 1:
+ # single return value function
+ if isinstance(return_type, Sequence):
+ return_type = return_type[0]
+ output = adj.add_var(return_type)
+ output_list = [output]
+ else:
+ # multiple return value function
+ output = [adj.add_var(v) for v in return_type]
+ output_list = output
+
+ # If we have a built-in that requires special handling to dispatch
+ # the arguments to the underlying C++ function, then we can resolve
+ # these using the `dispatch_func`. Since this is only called from
+ # within codegen, we pass it directly `codegen.Var` objects,
+ # which allows for some more advanced resolution to be performed,
+ # for example by checking whether an argument corresponds to
+ # a literal value or references a variable.
+ extra_shared_memory = 0
+ if func.lto_dispatch_func is not None:
+ func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
+ func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
+ )
+ elif func.dispatch_func is not None:
+ func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
+ else:
+ func_args = tuple(bound_args.values())
+ template_args = ()
+
+ func_args = tuple(adj.register_var(x) for x in func_args)
+ func_name = compute_type_str(func.native_func, template_args)
+ use_initializer_list = func.initializer_list_func(bound_args, return_type)
+
+ fwd_args = []
+ for func_arg in func_args:
+ if not isinstance(func_arg, (Reference, warp.context.Function)):
+ func_arg_var = adj.load(func_arg)
+ else:
+ func_arg_var = func_arg
+
+ # if the argument is a function (and not a builtin), then build it recursively
+ if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+ adj.builder.build_function(func_arg_var)
+
+ fwd_args.append(strip_reference(func_arg_var))
+
+ if return_type is None:
+ # handles expression (zero output) functions, e.g.: void do_something();
+ forward_call = (
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
+ )
+ replay_call = forward_call
+ if func.custom_replay_func is not None or func.replay_snippet is not None:
+ replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
+
+ elif not isinstance(return_type, Sequence) or len(return_type) == 1:
+ # handle simple function (one output)
+ forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
+ replay_call = forward_call
+ if func.custom_replay_func is not None:
+ replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
+
+ else:
+ # handle multiple value functions
+ forward_call = (
+ f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
+ )
+ replay_call = forward_call
+
+ if func.skip_replay:
+ adj.add_forward(forward_call, replay="// " + replay_call)
+ else:
+ adj.add_forward(forward_call, replay=replay_call)
+
+ if not func.missing_grad and len(func_args):
+ adj_args = tuple(strip_reference(x) for x in func_args)
+ reverse_has_output_args = (
+ func.require_original_output_arg or len(output_list) > 1
+ ) and func.custom_grad_func is None
+ arg_str = adj.format_reverse_call_args(
+ fwd_args,
+ adj_args,
+ output_list,
+ use_initializer_list,
+ has_output_args=reverse_has_output_args,
+ require_original_output_arg=func.require_original_output_arg,
+ )
+ if arg_str is not None:
+ reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
+ adj.add_reverse(reverse_call)
+
+ # update our smem roofline requirements based on any
+ # shared memory required by the dependent function call
+ if not func.is_builtin():
+ adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
+ else:
+ adj.alloc_shared_extra(extra_shared_memory)
+
+ return output
+
+ def add_builtin_call(adj, func_name, args, min_outputs=None):
+ func = warp.context.builtin_functions[func_name]
+ return adj.add_call(func, args, {}, {}, min_outputs=min_outputs)
+
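+ # note: in the replay pass a return becomes `goto label_N`, paired with a `label_N:;`
+ # emitted into the reverse pass so adjoint accumulation lines up with the return site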
+ def add_return(adj, var):
+ if var is None or len(var) == 0:
+ adj.add_forward("return;", f"goto label{adj.label_count};")
+ elif len(var) == 1:
+ adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
+ adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
+ else:
+ for i, v in enumerate(var):
+ adj.add_forward(f"ret_{i} = {v.emit()};")
+ adj.add_reverse(f"adj_{v} += adj_ret_{i};")
+ adj.add_forward("return;", f"goto label{adj.label_count};")
+
+ adj.add_reverse(f"label{adj.label_count}:;")
+
+ adj.label_count += 1
+
+ # define an if statement
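+ # note: reverse statements are emitted in reverse order, so begin_if contributes the
+ # closing "}" and end_if the opening "if (cond) {" of the reverse-pass conditional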
+ def begin_if(adj, cond):
+ cond = adj.load(cond)
+ adj.add_forward(f"if ({cond.emit()}) {{")
+ adj.add_reverse("}")
+
+ adj.indent()
+
+ def end_if(adj, cond):
+ adj.dedent()
+
+ adj.add_forward("}")
+ cond = adj.load(cond)
+ adj.add_reverse(f"if ({cond.emit()}) {{")
+
+ def begin_else(adj, cond):
+ cond = adj.load(cond)
+ adj.add_forward(f"if (!{cond.emit()}) {{")
+ adj.add_reverse("}")
+
+ adj.indent()
+
+ def end_else(adj, cond):
+ adj.dedent()
+
+ adj.add_forward("}")
+ cond = adj.load(cond)
+ adj.add_reverse(f"if (!{cond.emit()}) {{")
+
+ # define a for-loop
+ def begin_for(adj, iter):
+ cond_block = adj.begin_block("for")
+ adj.loop_blocks.append(cond_block)
+ adj.add_forward(f"start_{cond_block.label}:;")
+ adj.indent()
+
+ # evaluate cond
+ adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto end_{cond_block.label};")
+
+ # evaluate iter
+ val = adj.add_builtin_call("iter_next", [iter])
+
+ adj.begin_block()
+
+ return val
+
+ def end_for(adj, iter):
+ body_block = adj.end_block()
+ cond_block = adj.end_block()
+ adj.loop_blocks.pop()
+
+ ####################
+ # forward pass
+
+ for i in cond_block.body_forward:
+ adj.blocks[-1].body_forward.append(i)
+
+ for i in body_block.body_forward:
+ adj.blocks[-1].body_forward.append(i)
+
+ adj.add_forward(f"goto start_{cond_block.label};", skip_replay=True)
+
+ adj.dedent()
+ adj.add_forward(f"end_{cond_block.label}:;", skip_replay=True)
+
+ ####################
+ # reverse pass
+
+ reverse = []
+
+ # reverse iterator
+ reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")
+
+ for i in cond_block.body_forward:
+ reverse.append(i)
+
+ # zero adjoints
+ for i in body_block.vars:
+ if is_tile(i.type):
+ if i.type.owner:
+ reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+ else:
+ reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
+
+ # replay
+ for i in body_block.body_replay:
+ reverse.append(i)
+
+ # reverse
+ for i in reversed(body_block.body_reverse):
+ reverse.append(i)
+
+ reverse.append(adj.indentation + f"\tgoto start_{cond_block.label};")
+ reverse.append(adj.indentation + f"end_{cond_block.label}:;")
+
+ adj.blocks[-1].body_reverse.extend(reversed(reverse))
+
+ # define a while loop
+ def begin_while(adj, cond):
+ # evaluate condition in its own block
+ # so we can control replay
+ cond_block = adj.begin_block("while")
+ adj.loop_blocks.append(cond_block)
+ cond_block.body_forward.append(f"start_{cond_block.label}:;")
+
+ c = adj.eval(cond)
+ c = adj.load(c)
+
+ cond_block.body_forward.append(f"if (({c.emit()}) == false) goto end_{cond_block.label};")
+
+ # begin block around loop body
+ adj.begin_block()
+ adj.indent()
+
+ def end_while(adj):
+ adj.dedent()
+ body_block = adj.end_block()
+ cond_block = adj.end_block()
+ adj.loop_blocks.pop()
+
+ ####################
+ # forward pass
+
+ for i in cond_block.body_forward:
+ adj.blocks[-1].body_forward.append(i)
+
+ for i in body_block.body_forward:
+ adj.blocks[-1].body_forward.append(i)
+
+ adj.blocks[-1].body_forward.append(f"goto start_{cond_block.label};")
+ adj.blocks[-1].body_forward.append(f"end_{cond_block.label}:;")
+
+ ####################
+ # reverse pass
+ reverse = []
+
+ # cond
+ for i in cond_block.body_forward:
+ reverse.append(i)
+
+ # zero adjoints of local vars
+ for i in body_block.vars:
+ reverse.append(f"{i.emit_adj()} = {{}};")
+
+ # replay
+ for i in body_block.body_replay:
+ reverse.append(i)
+
+ # reverse
+ for i in reversed(body_block.body_reverse):
+ reverse.append(i)
+
+ reverse.append(f"goto start_{cond_block.label};")
+ reverse.append(f"end_{cond_block.label}:;")
+
+ # output
+ adj.blocks[-1].body_reverse.extend(reversed(reverse))
+
+ def emit_FunctionDef(adj, node):
+ adj.fun_def_lineno = node.lineno
+
+ for f in node.body:
+ # Skip variable creation for standalone constants, including docstrings
+ if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
+ continue
+ adj.eval(f)
+
+ if adj.return_var is not None and len(adj.return_var) == 1:
+ if not isinstance(node.body[-1], ast.Return):
+ adj.add_forward("return {};", skip_replay=True)
+
+ # native function case: return type is specified, e.g.: -> int or -> wp.float32
+ is_func_native = False
+ if node.decorator_list is not None and len(node.decorator_list) == 1:
+ obj = node.decorator_list[0]
+ if isinstance(obj, ast.Call):
+ if isinstance(obj.func, ast.Attribute):
+ if obj.func.attr == "func_native":
+ is_func_native = True
+ if is_func_native and node.returns is not None:
+ if isinstance(node.returns, ast.Name): # python built-in type
+ var = Var(label="return_type", type=eval(node.returns.id))
+ elif isinstance(node.returns, ast.Attribute): # warp type
+ var = Var(label="return_type", type=eval(node.returns.attr))
+ else:
+ raise WarpCodegenTypeError("Native function return type not recognized")
+ adj.return_var = (var,)
+
+ def emit_If(adj, node):
+ if len(node.body) == 0:
+ return None
+
+ # eval condition
+ cond = adj.eval(node.test)
+
+ if cond.constant is not None:
+ # resolve constant condition
+ if cond.constant:
+ for stmt in node.body:
+ adj.eval(stmt)
+ else:
+ for stmt in node.orelse:
+ adj.eval(stmt)
+ return None
+
+ # save symbol map
+ symbols_prev = adj.symbols.copy()
+
+ # eval body
+ adj.begin_if(cond)
+
+ for stmt in node.body:
+ adj.eval(stmt)
+
+ adj.end_if(cond)
+
+ # detect existing symbols with conflicting definitions (variables assigned inside the branch)
+ # and resolve with a phi (select) function
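+ # e.g. given `x = a` followed by `if cond: x = b`, the merged symbol becomes
+ # x = where(cond, x_branch, x_prev) so both paths feed downstream uses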
+ for items in symbols_prev.items():
+ sym = items[0]
+ var1 = items[1]
+ var2 = adj.symbols[sym]
+
+ if var1 != var2:
+ # insert a phi function that selects var1, var2 based on cond
+ out = adj.add_builtin_call("where", [cond, var2, var1])
+ adj.symbols[sym] = out
+
+ symbols_prev = adj.symbols.copy()
+
+ # evaluate 'else' statement as if (!cond)
+ if len(node.orelse) > 0:
+ adj.begin_else(cond)
+
+ for stmt in node.orelse:
+ adj.eval(stmt)
+
+ adj.end_else(cond)
+
+ # detect existing symbols with conflicting definitions (variables assigned inside the else)
+ # and resolve with a phi (select) function
+ for items in symbols_prev.items():
+ sym = items[0]
+ var1 = items[1]
+ var2 = adj.symbols[sym]
+
+ if var1 != var2:
+ # insert a phi function that selects var1, var2 based on cond
+ # note the reversed order of vars since we want to use !cond as our select
+ out = adj.add_builtin_call("where", [cond, var1, var2])
+ adj.symbols[sym] = out
+
+ def emit_Compare(adj, node):
+ # node.left, node.ops (list of ops), node.comparators (things to compare to)
+ # e.g. (left ops[0] node.comparators[0]) ops[1] node.comparators[1]
+
+ left = adj.eval(node.left)
+ comps = [adj.eval(comp) for comp in node.comparators]
+ op_strings = [builtin_operators[type(op)] for op in node.ops]
+
+ return adj.add_comp(op_strings, left, comps)
+
+ def emit_BoolOp(adj, node):
+ # op, expr list values
+
+ op = node.op
+ if isinstance(op, ast.And):
+ func = "&&"
+ elif isinstance(op, ast.Or):
+ func = "||"
+ else:
+ raise WarpCodegenKeyError(f"Op {op} is not supported")
+
+ return adj.add_bool_op(func, [adj.eval(expr) for expr in node.values])
+
+ def emit_Name(adj, node):
+ # lookup symbol; if it has already been assigned to a variable then return the existing mapping
+ if node.id in adj.symbols:
+ return adj.symbols[node.id]
+
+ obj = adj.resolve_external_reference(node.id)
+
+ if obj is None:
+ raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))
+
+ if warp.types.is_value(obj):
+ # evaluate constant
+ out = adj.add_constant(obj)
+ adj.symbols[node.id] = out
+ return out
+
+ # the named object is either a function, class name, or module
+ # pass it back to the caller for processing
+ if isinstance(obj, warp.context.Function):
+ return obj
+ if isinstance(obj, type):
+ return obj
+ if isinstance(obj, types.ModuleType):
+ return obj
+
+ raise TypeError(f"Invalid external reference type: {type(obj)}")
+
+ @staticmethod
+ def resolve_type_attribute(var_type: type, attr: str):
+ if isinstance(var_type, type) and type_is_value(var_type):
+ if attr == "dtype":
+ return type_scalar_type(var_type)
+ elif attr == "length":
+ return type_length(var_type)
+
+ return getattr(var_type, attr, None)
+
+ def vector_component_index(adj, component, vector_type):
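+ # e.g. `v.y` on a wp.vec3 resolves to the constant component index 1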
+ if len(component) != 1:
+ raise WarpCodegenAttributeError(f"Vector swizzle must be a single character, got .{component}")
+
+ dim = vector_type._shape_[0]
+ swizzles = "xyzw"[0:dim]
+ if component not in swizzles:
+ raise WarpCodegenAttributeError(
+ f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
+ )
+
+ index = swizzles.index(component)
+ index = adj.add_constant(index)
+ return index
+
+ @staticmethod
+ def is_differentiable_value_type(var_type):
+ # checks that the argument type is a value type (i.e., not an array)
+ # possibly holding differentiable values (for which gradients must be accumulated)
+ return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
+
+ def emit_Attribute(adj, node):
+ if hasattr(node, "is_adjoint"):
+ node.value.is_adjoint = True
+
+ aggregate = adj.eval(node.value)
+
+ try:
+ if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
+ out = getattr(aggregate, node.attr)
+
+ if warp.types.is_value(out):
+ return adj.add_constant(out)
+
+ return out
+
+ if hasattr(node, "is_adjoint"):
+ # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
+ attr_name = aggregate.label + "." + node.attr
+ attr_type = aggregate.type.vars[node.attr].type
+
+ return Var(attr_name, attr_type)
+
+ aggregate_type = strip_reference(aggregate.type)
+
+ # reading a vector component
+ if type_is_vector(aggregate_type):
+ index = adj.vector_component_index(node.attr, aggregate_type)
+
+ return adj.add_builtin_call("extract", [aggregate, index])
+
+ else:
+ attr_type = Reference(aggregate_type.vars[node.attr].type)
+ attr = adj.add_var(attr_type)
+
+ if is_reference(aggregate.type):
+ adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
+ else:
+ adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
+
+ if adj.is_differentiable_value_type(strip_reference(attr_type)):
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
+ else:
+ adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
+
+ return attr
+
+ except (KeyError, AttributeError) as e:
+ # Try resolving as type attribute
+ aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate
+
+ type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
+ if type_attribute is not None:
+ return type_attribute
+
+ if isinstance(aggregate, Var):
+ raise WarpCodegenAttributeError(
+ f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
+ ) from e
+ raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
+
+ def emit_Assert(adj, node):
+ # eval condition
+ cond = adj.eval(node.test)
+ cond = adj.load(cond)
+
+ source_segment = ast.get_source_segment(adj.source, node)
+ # If a message was provided with the assert, " marks could interfere with the generated code, so escape them
+ escaped_segment = source_segment.replace('"', '\\"')
+
+ adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
+
+ def emit_Constant(adj, node):
+ if node.value is None:
+ raise WarpCodegenTypeError("None type unsupported")
+ else:
+ return adj.add_constant(node.value)
+
+ def emit_BinOp(adj, node):
+ # evaluate binary operator arguments
+
+ if warp.config.verify_autograd_array_access:
+ # array overwrite tracking: in-place operators are a special case
+ # x[tid] = x[tid] + 1 is a read followed by a write, but we only want to record the write
+ # so we save the current arg read flags and restore them after lhs eval
+ is_read_states = []
+ for arg in adj.args:
+ is_read_states.append(arg.is_read)
+
+ # evaluate lhs binary operator argument
+ left = adj.eval(node.left)
+
+ if warp.config.verify_autograd_array_access:
+ # restore arg read flags
+ for i, arg in enumerate(adj.args):
+ arg.is_read = is_read_states[i]
+
+ # evaluate rhs binary operator argument
+ right = adj.eval(node.right)
+
+ name = builtin_operators[type(node.op)]
+
+ try:
+ # Check if there is any user-defined overload for this operator
+ user_func = adj.resolve_external_reference(name)
+ if isinstance(user_func, warp.context.Function):
+ return adj.add_call(user_func, (left, right), {}, {})
+ except WarpCodegenError:
+ pass
+
+ return adj.add_builtin_call(name, [left, right])
+
+ def emit_UnaryOp(adj, node):
+ # evaluate unary op arguments
+ arg = adj.eval(node.operand)
+
+ # evaluate expression to a compile-time constant if arg is a constant
+ if arg.constant is not None and math.isfinite(arg.constant):
+ if isinstance(node.op, ast.USub):
+ return adj.add_constant(-arg.constant)
+
+ name = builtin_operators[type(node.op)]
+
+ return adj.add_builtin_call(name, [arg])
+
+ def materialize_redefinitions(adj, symbols):
+ # detect symbols with conflicting definitions (assigned inside the for loop)
+ for items in symbols.items():
+ sym = items[0]
+ if adj.is_constant_iter_symbol(sym):
+ # ignore constant overwriting in for-loops if it is a loop iterator
+ # (it is no problem to unroll static loops multiple times in sequence)
+ continue
+
+ var1 = items[1]
+ var2 = adj.symbols[sym]
+
+ if var1 != var2:
+ if warp.config.verbose and not adj.custom_reverse_mode:
+ lineno = adj.lineno + adj.fun_lineno
+ line = adj.source_lines[adj.lineno]
+ msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
+ print(msg)
+
+ if var1.constant is not None:
+ raise WarpCodegenError(
+ f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
+ )
+
+ # overwrite the old variable value (violates SSA)
+ adj.add_builtin_call("assign", [var1, var2])
+
+ # reset the symbol to point to the original variable
+ adj.symbols[sym] = var1
+
+ def emit_While(adj, node):
+ adj.begin_while(node.test)
+
+ adj.loop_symbols.append(adj.symbols.copy())
+
+ # eval body
+ for s in node.body:
+ adj.eval(s)
+
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
+ adj.loop_symbols.pop()
+
+ adj.end_while()
+
+ def eval_num(adj, a):
+ if isinstance(a, ast.Constant):
+ return True, a.value
+ if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
+ # Negative constant
+ return True, -a.operand.value
+
+ # try and resolve the expression to an object
+ # e.g.: wp.constant in the globals scope
+ obj, _ = adj.resolve_static_expression(a)
+
+ if obj is None:
+ obj = adj.eval(a)
+
+ if isinstance(obj, Var) and obj.constant is not None:
+ obj = obj.constant
+
+ return warp.types.is_int(obj), obj
+
+ # detects whether a loop contains a break (or continue) statement
+ def contains_break(adj, body):
+ for s in body:
+ if isinstance(s, ast.Break):
+ return True
+ elif isinstance(s, ast.Continue):
+ return True
+ elif isinstance(s, ast.If):
+ if adj.contains_break(s.body):
+ return True
+ if adj.contains_break(s.orelse):
+ return True
+ else:
+ # note that nested for or while loops containing a break statement
+ # do not affect the current loop
+ pass
+
+ return False
+
+ # returns a constant range() if unrollable, otherwise None
+ def get_unroll_range(adj, loop):
+ if (
+ not isinstance(loop.iter, ast.Call)
+ or not isinstance(loop.iter.func, ast.Name)
+ or loop.iter.func.id != "range"
+ or len(loop.iter.args) == 0
+ or len(loop.iter.args) > 3
+ ):
+ return None
+
+ # if all range() arguments are numeric constants we will unroll
+ # note that this only handles trivial constants, it will not unroll
+ # constant compile-time expressions e.g.: range(0, 3*2)
+
+ # Evaluate the arguments and check that they are numeric constants
+ # It is important to do that in one pass, so that if evaluating these arguments has side effects
+ # the code does not get generated more than once
+ range_args = [adj.eval_num(arg) for arg in loop.iter.args]
+ arg_is_numeric, arg_values = zip(*range_args)
+
+ if all(arg_is_numeric):
+ # All arguments are numeric constants
+
+ # range(end)
+ if len(loop.iter.args) == 1:
+ start = 0
+ end = arg_values[0]
+ step = 1
+
+ # range(start, end)
+ elif len(loop.iter.args) == 2:
+ start = arg_values[0]
+ end = arg_values[1]
+ step = 1
+
+ # range(start, end, step)
+ elif len(loop.iter.args) == 3:
+ start = arg_values[0]
+ end = arg_values[1]
+ step = arg_values[2]
+
+ # test if we're above max unroll count
+ max_iters = abs(end - start) // abs(step)
+
+ if "max_unroll" in adj.builder_options:
+ max_unroll = adj.builder_options["max_unroll"]
+ else:
+ max_unroll = warp.config.max_unroll
+
+ ok_to_unroll = True
+
+ if max_iters > max_unroll:
+ if warp.config.verbose:
+ print(
+ f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
+ )
+ ok_to_unroll = False
+
+ elif adj.contains_break(loop.body):
+ if warp.config.verbose:
+ print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
+ ok_to_unroll = False
+
+ if ok_to_unroll:
+ return range(start, end, step)
+
+ # Unrolling is not possible, so the range needs to be evaluated dynamically
+ range_call = adj.add_builtin_call(
+ "range",
+ [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
+ )
+ return range_call
+
+ def record_constant_iter_symbol(adj, sym):
+ adj.loop_const_iter_symbols.add(sym)
+
+ def is_constant_iter_symbol(adj, sym):
+ return sym in adj.loop_const_iter_symbols
+
+ def emit_For(adj, node):
+ # try and unroll simple range() statements that use constant args
+ unroll_range = adj.get_unroll_range(node)
+
+ if isinstance(unroll_range, range):
+ const_iter_sym = node.target.id
+ # prevent constant conflicts in `materialize_redefinitions()`
+ adj.record_constant_iter_symbol(const_iter_sym)
+
+ # unroll static for-loop
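+ # e.g. `for i in range(2): f(i)` emits the body twice, binding i to the constants 0 and 1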
+ for i in unroll_range:
+ const_iter = adj.add_constant(i)
+ adj.symbols[const_iter_sym] = const_iter
+
+ # eval body
+ for s in node.body:
+ adj.eval(s)
+
+ # otherwise generate a dynamic loop
+ else:
+ # evaluate the Iterable -- only if not previously evaluated when trying to unroll
+ if unroll_range is not None:
+ # Range has already been evaluated when trying to unroll, do not re-evaluate
+ iter = unroll_range
+ else:
+ iter = adj.eval(node.iter)
+
+ adj.symbols[node.target.id] = adj.begin_for(iter)
+
+ # for loops should be side-effect free, here we store a copy
+ adj.loop_symbols.append(adj.symbols.copy())
+
+ # eval body
+ for s in node.body:
+ adj.eval(s)
+
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
+ adj.loop_symbols.pop()
+
+ adj.end_for(iter)
+
+ def emit_Break(adj, node):
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
+
+ adj.add_forward(f"goto end_{adj.loop_blocks[-1].label};")
+
+ def emit_Continue(adj, node):
+ adj.materialize_redefinitions(adj.loop_symbols[-1])
+
+ adj.add_forward(f"goto start_{adj.loop_blocks[-1].label};")
+
+ def emit_Expr(adj, node):
+ return adj.eval(node.value)
+
+ def check_tid_in_func_error(adj, node):
+ if adj.is_user_function:
+ if hasattr(node.func, "attr") and node.func.attr == "tid":
+ lineno = adj.lineno + adj.fun_lineno
+ line = adj.source_lines[adj.lineno]
+ raise WarpCodegenError(
+ "tid() may only be called from a Warp kernel, not a Warp function. "
+ "Instead, obtain the indices from a @wp.kernel and pass them as "
+ f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
+ )
+
+ def resolve_arg(adj, arg):
+ # Always start by trying to evaluate the argument since it can help
+ # detect some issues such as global variables being accessed.
+ try:
+ var = adj.eval(arg)
+ except (WarpCodegenError, WarpCodegenKeyError) as e:
+ error = e
+ else:
+ error = None
+
+ # Check if we can resolve the argument as a static expression.
+ # If not, return the variable resulting from evaluating the argument.
+ expr, _ = adj.resolve_static_expression(arg)
+ if expr is None:
+ if error is not None:
+ raise error
+
+ return var
+
+ if isinstance(expr, (type, Var, warp.context.Function)):
+ return expr
+
+ return adj.add_constant(expr)
+
+ def emit_Call(adj, node):
+ adj.check_tid_in_func_error(node)
+
+ # try and lookup function in globals by
+ # resolving path (e.g.: module.submodule.attr)
+ if hasattr(node.func, "warp_func"):
+ func = node.func.warp_func
+ path = []
+ else:
+ func, path = adj.resolve_static_expression(node.func)
+ if func is None:
+ func = adj.eval(node.func)
+
+ if adj.is_static_expression(func):
+ # try to evaluate wp.static() expressions
+ obj, _ = adj.evaluate_static_expression(node)
+ if obj is not None:
+ if isinstance(obj, warp.context.Function):
+ # special handling for wp.static() evaluating to a function
+ return obj
+ else:
+ out = adj.add_constant(obj)
+ return out
+
+ type_args = {}
+
+ if len(path) > 0 and not isinstance(func, warp.context.Function):
+ attr = path[-1]
+ caller = func
+ func = None
+
+ # try and lookup function name in builtins (e.g.: using `dot` directly without wp prefix)
+ if attr in warp.context.builtin_functions:
+ func = warp.context.builtin_functions[attr]
+
+ # vector class type e.g.: wp.vec3f constructor
+ if func is None and hasattr(caller, "_wp_generic_type_str_"):
+ func = warp.context.builtin_functions.get(caller._wp_constructor_)
+
+ # scalar class type e.g.: wp.int8 constructor
+ if func is None and hasattr(caller, "__name__") and caller.__name__ in warp.context.builtin_functions:
+ func = warp.context.builtin_functions.get(caller.__name__)
+
+ # struct constructor
+ if func is None and isinstance(caller, Struct):
+ adj.builder.build_struct_recursive(caller)
+ if node.args or node.keywords:
+ func = caller.value_constructor
+ else:
+ func = caller.default_constructor
+
+ if hasattr(caller, "_wp_type_args_"):
+ type_args = caller._wp_type_args_
+
+ if func is None:
+ raise WarpCodegenError(
+ f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
+ )
+
+ # Check if any argument corresponds to an unsupported construct.
+ # Tuples are supported in the context of assigning multiple variables
+ # at once, but not in place of vectors when calling built-ins like
+ # `wp.length((1, 2, 3))`.
+ # Therefore, we need to catch this specific case here instead of
+ # more generally in `adj.eval()`.
+ for arg in node.args:
+ if isinstance(arg, ast.Tuple):
+ raise WarpCodegenError(
+ "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` instead."
+ )
+
+ # get expected return count, e.g.: for multi-assignment
+ min_outputs = None
+ if hasattr(node, "expects"):
+ min_outputs = node.expects
+
+ # Evaluate all positional and keyword arguments.
+ args = tuple(adj.resolve_arg(x) for x in node.args)
+ kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
+
+ # add the call and build the callee adjoint if needed (func.adj)
+ out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
+
+ if warp.config.verify_autograd_array_access:
+ # Extract the types and values passed as arguments to the function call.
+ arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
+ kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}
+
+ # Resolve the exact function signature among any existing overload.
+ resolved_func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)
+
+ # update arg read/write states according to what happens to that arg in the called function
+ if hasattr(resolved_func, "adj"):
+ for i, arg in enumerate(args):
+ if resolved_func.adj.args[i].is_write:
+ kernel_name = adj.fun_name
+ filename = adj.filename
+ lineno = adj.lineno + adj.fun_lineno
+ arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
+ if resolved_func.adj.args[i].is_read:
+ arg.mark_read()
+
+ return out
+
+ def emit_Index(adj, node):
+ # the ast.Index node appears in Python 3.7
+ # when performing array slices, e.g.: x = arr[i],
+ # but in Python 3.8 and higher it does not appear
+
+ if hasattr(node, "is_adjoint"):
+ node.value.is_adjoint = True
+
+ return adj.eval(node.value)
+
+ # returns the object being indexed, and the list of indices
+ def eval_subscript(adj, node):
+ # We want to coalesce multi-dimensional array indexing into a single operation. This needs
+ # to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices, and
+ # essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing
+ # right-to-left, and we don't want to evaluate the index expressions prematurely, this
+ # requires a first loop to check if this `node` only performs indexing on the array, and a
+ # second loop to evaluate and collect index variables.
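+ # e.g. `a[i, j]` on a 2D array returns (a, [i, j]), and `a[i][j]` coalesces to the same result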
2336
+ root = node
2337
+ count = 0
2338
+ array = None
2339
+ while isinstance(root, ast.Subscript):
2340
+ if isinstance(root.slice, ast.Tuple):
2341
+ # handles the x[i, j] case (Python 3.8.x upward)
2342
+ count += len(root.slice.elts)
2343
+ elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
2344
+ # handles the x[i, j] case (Python 3.7.x)
2345
+ count += len(root.slice.value.elts)
2346
+ else:
2347
+ # simple expression, e.g.: x[i]
2348
+ count += 1
2349
+
2350
+ if isinstance(root.value, ast.Name):
2351
+ symbol = adj.emit_Name(root.value)
2352
+ symbol_type = strip_reference(symbol.type)
2353
+ if is_array(symbol_type):
2354
+ array = symbol
2355
+ break
2356
+
2357
+ root = root.value
2358
+
2359
+ # If not all indices index into the array, just evaluate the right-most indexing operation.
2360
+ if not array or (count > array.type.ndim):
2361
+ count = 1
2362
+
2363
+ indices = []
2364
+ root = node
2365
+ while len(indices) < count:
2366
+ if isinstance(root.slice, ast.Tuple):
2367
+ ij = [adj.eval(arg) for arg in root.slice.elts]
2368
+ elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
2369
+ ij = [adj.eval(arg) for arg in root.slice.value.elts]
2370
+ else:
2371
+ ij = [adj.eval(root.slice)]
2372
+
2373
+ indices = ij + indices # prepend
2374
+
2375
+ root = root.value
2376
+
2377
+ target = adj.eval(root)
2378
+
2379
+ return target, indices
2380
+
2381
+ def emit_Subscript(adj, node):
2382
+ if hasattr(node.value, "attr") and node.value.attr == "adjoint":
2383
+ # handle adjoint of a variable, i.e. wp.adjoint[var]
2384
+ node.slice.is_adjoint = True
2385
+ var = adj.eval(node.slice)
2386
+ var_name = var.label
2387
+ var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
2388
+ return var
2389
+
2390
+ target, indices = adj.eval_subscript(node)
2391
+
2392
+ target_type = strip_reference(target.type)
2393
+ if is_array(target_type):
2394
+ if len(indices) == target_type.ndim:
2395
+ # handles array loads (where each dimension has an index specified)
2396
+ out = adj.add_builtin_call("address", [target, *indices])
2397
+
2398
+ if warp.config.verify_autograd_array_access:
2399
+ target.mark_read()
2400
+
2401
+ else:
2402
+ # handles array views (fewer indices than dimensions)
2403
+ out = adj.add_builtin_call("view", [target, *indices])
2404
+
2405
+ if warp.config.verify_autograd_array_access:
2406
+ # store reference to target Var to propagate downstream read/write state back to root arg Var
2407
+ out.parent = target
2408
+
2409
+ # view arg inherits target Var's read/write states
2410
+ out.is_read = target.is_read
2411
+ out.is_write = target.is_write
2412
+
2413
+ elif is_tile(target_type):
2414
+ if len(indices) == len(target_type.shape):
2415
+ # handles extracting a single element from a tile
2416
+ out = adj.add_builtin_call("tile_extract", [target, *indices])
2417
+ elif len(indices) < len(target_type.shape):
2418
+ # handles tile views
2419
+ out = adj.add_builtin_call("tile_view", [target, indices])
2420
+ else:
2421
+ raise RuntimeError(
2422
+ f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
2423
+ )
2424
+
2425
+ else:
2426
+ # handles non-array type indexing, e.g: vec3, mat33, etc
2427
+ out = adj.add_builtin_call("extract", [target, *indices])
2428
+
2429
+ return out
2430
+
2431
+ def emit_Assign(adj, node):
2432
+ if len(node.targets) != 1:
2433
+ raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
2434
+
2435
+ lhs = node.targets[0]
2436
+
2437
+ if not isinstance(lhs, ast.Tuple):
2438
+ # Check if the rhs corresponds to an unsupported construct.
2439
+ # Tuples are supported in the context of assigning multiple variables
2440
+ # at once, but not for simple assignments like `x = (1, 2, 3)`.
2441
+ # Therefore, we need to catch this specific case here instead of
2442
+ # more generally in `adj.eval()`.
2443
+ if isinstance(node.value, ast.List):
2444
+ raise WarpCodegenError(
2445
+ "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
2446
+ )
2447
+ elif isinstance(node.value, ast.Tuple):
2448
+ raise WarpCodegenError(
2449
+ "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
2450
+ )
2451
+
2452
+ # handle the case where we are assigning multiple output variables
2453
+ if isinstance(lhs, ast.Tuple):
2454
+ # record the expected number of outputs on the node
2455
+ # we do this so we can decide which function to
2456
+ # call based on the number of expected outputs
2457
+ if isinstance(node.value, ast.Call):
2458
+ node.value.expects = len(lhs.elts)
2459
+
2460
+ # evaluate values
2461
+ if isinstance(node.value, ast.Tuple):
2462
+ out = [adj.eval(v) for v in node.value.elts]
2463
+ else:
2464
+ out = adj.eval(node.value)
2465
+
2466
+ names = []
2467
+ for v in lhs.elts:
2468
+ if isinstance(v, ast.Name):
2469
+ names.append(v.id)
2470
+ else:
2471
+ raise WarpCodegenError(
2472
+ "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
2473
+ )
2474
+
2475
+ if len(names) != len(out):
2476
+ raise WarpCodegenError(
2477
+ f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
2478
+ )
2479
+
2480
+ for name, rhs in zip(names, out):
2481
+ if name in adj.symbols:
2482
+ if not types_equal(rhs.type, adj.symbols[name].type):
2483
+ raise WarpCodegenTypeError(
2484
+ f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
2485
+ )
2486
+
2487
+ adj.symbols[name] = rhs
2488
+
2489
+ # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
2490
+ elif isinstance(lhs, ast.Subscript):
2491
+             rhs = adj.eval(node.value)
+
+             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
+                 # handle adjoint of a variable, i.e. wp.adjoint[var]
+                 lhs.slice.is_adjoint = True
+                 src_var = adj.eval(lhs.slice)
+                 var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
+                 adj.add_forward(f"{var.emit()} = {rhs.emit()};")
+                 return
+
+             target, indices = adj.eval_subscript(lhs)
+
+             target_type = strip_reference(target.type)
+
+             if is_array(target_type):
+                 adj.add_builtin_call("array_store", [target, *indices, rhs])
+
+                 if warp.config.verify_autograd_array_access:
+                     kernel_name = adj.fun_name
+                     filename = adj.filename
+                     lineno = adj.lineno + adj.fun_lineno
+
+                     target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
+
+             elif is_tile(target_type):
+                 adj.add_builtin_call("assign", [target, *indices, rhs])
+
+             elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
+                 # recursively unwind AST, stopping at penultimate node
+                 node = lhs
+                 while hasattr(node, "value"):
+                     if hasattr(node.value, "value"):
+                         node = node.value
+                     else:
+                         break
+                 # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
+                 if hasattr(node, "attr") and node.attr == "adjoint":
+                     attr = adj.add_builtin_call("index", [target, *indices])
+                     adj.add_builtin_call("store", [attr, rhs])
+                     return
+
+                 # TODO: array vec component case
+                 if is_reference(target.type):
+                     attr = adj.add_builtin_call("indexref", [target, *indices])
+                     adj.add_builtin_call("store", [attr, rhs])
+
+                     if warp.config.verbose and not adj.custom_reverse_mode:
+                         lineno = adj.lineno + adj.fun_lineno
+                         line = adj.source_lines[adj.lineno]
+                         node_source = adj.get_node_source(lhs.value)
+                         print(
+                             f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
+                         )
+                 else:
+                     if warp.config.enable_vector_component_overwrites:
+                         out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])
+
+                         # re-point target symbol to out var
+                         for id in adj.symbols:
+                             if adj.symbols[id] == target:
+                                 adj.symbols[id] = out
+                                 break
+                     else:
+                         adj.add_builtin_call("assign_inplace", [target, *indices, rhs])
+
+             else:
+                 raise WarpCodegenError(
+                     f"Can only subscript assign array, vector, quaternion, and matrix types, got {target_type}"
+                 )
+
+         elif isinstance(lhs, ast.Name):
+             # symbol name
+             name = lhs.id
+
+             # evaluate rhs
+             rhs = adj.eval(node.value)
+
+             # check type matches if symbol already defined
+             if name in adj.symbols:
+                 if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
+                     raise WarpCodegenTypeError(
+                         f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
+                     )
+
+             # handle simple assignment case (a = b), where we generate a value copy rather than reference
+             if isinstance(node.value, ast.Name) or is_reference(rhs.type):
+                 out = adj.add_builtin_call("copy", [rhs])
+             else:
+                 out = rhs
+
+             # update symbol map (assumes lhs is a Name node)
+             adj.symbols[name] = out
+
+         elif isinstance(lhs, ast.Attribute):
+             rhs = adj.eval(node.value)
+             aggregate = adj.eval(lhs.value)
+             aggregate_type = strip_reference(aggregate.type)
+
+             # assigning to a vector or quaternion component
+             if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
+                 index = adj.vector_component_index(lhs.attr, aggregate_type)
+
+                 if is_reference(aggregate.type):
+                     attr = adj.add_builtin_call("indexref", [aggregate, index])
+                     adj.add_builtin_call("store", [attr, rhs])
+                 else:
+                     if warp.config.enable_vector_component_overwrites:
+                         out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])
+
+                         # re-point target symbol to out var
+                         for id in adj.symbols:
+                             if adj.symbols[id] == aggregate:
+                                 adj.symbols[id] = out
+                                 break
+                     else:
+                         adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
+
+             else:
+                 attr = adj.emit_Attribute(lhs)
+                 if is_reference(attr.type):
+                     adj.add_builtin_call("store", [attr, rhs])
+                 else:
+                     adj.add_builtin_call("assign", [attr, rhs])
+
+                 if warp.config.verbose and not adj.custom_reverse_mode:
+                     lineno = adj.lineno + adj.fun_lineno
+                     line = adj.source_lines[adj.lineno]
+                     msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
+                     print(msg)
+
+         else:
+             raise WarpCodegenError("Error, unsupported assignment statement.")
+
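+     # Example (illustrative sketch, not part of the original source): the
+     # branches above correspond to kernel-level assignment forms such as
+     #
+     #     @wp.kernel
+     #     def k(a: wp.array(dtype=wp.vec3)):
+     #         i = wp.tid()     # ast.Name: plain symbol assignment
+     #         v = wp.vec3()    # ast.Name: symbol map updated to the new var
+     #         v[1] = 2.0       # vector component write -> "assign_inplace"
+     #         v.x = 3.0        # ast.Attribute component write
+     #         a[i] = v         # array element write -> "array_store"
+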
+     def emit_Return(adj, node):
+         if node.value is None:
+             var = None
+         elif isinstance(node.value, ast.Tuple):
+             var = tuple(adj.eval(arg) for arg in node.value.elts)
+         else:
+             var = (adj.eval(node.value),)
+
+         if adj.return_var is not None:
+             old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
+             new_ctypes = tuple(v.ctype(value_type=True) for v in var)
+             if old_ctypes != new_ctypes:
+                 raise WarpCodegenTypeError(
+                     f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
+                 )
+
+         if var is not None:
+             adj.return_var = ()
+             for ret in var:
+                 if is_reference(ret.type):
+                     ret_var = adj.add_builtin_call("copy", [ret])
+                 else:
+                     ret_var = ret
+                 adj.return_var += (ret_var,)
+
+         adj.add_return(adj.return_var)
+
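+     # Example (illustrative, not from the original source): multiple return
+     # values are packed into a tuple of vars, and every return statement in a
+     # function must agree on its C types:
+     #
+     #     @wp.func
+     #     def minmax(a: float, b: float):
+     #         return wp.min(a, b), wp.max(a, b)   # return_var -> (float32, float32)
+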
+     def emit_AugAssign(adj, node):
+         lhs = node.target
+
+         # replace augmented assignment with assignment statement + binary op (default behaviour)
+         def make_new_assign_statement():
+             new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
+             adj.eval(new_node)
+
+         if isinstance(lhs, ast.Subscript):
+             rhs = adj.eval(node.value)
+
+             # wp.adjoint[var] appears in custom grad functions, and does not require
+             # special consideration in the AugAssign case
+             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
+                 make_new_assign_statement()
+                 return
+
+             target, indices = adj.eval_subscript(lhs)
+
+             target_type = strip_reference(target.type)
+
+             if is_array(target_type):
+                 # array dtypes int8, uint8, int16, uint16 are not suitable for atomic array accumulation
+                 if target_type.dtype in warp.types.non_atomic_types:
+                     make_new_assign_statement()
+                     return
+
+                 # the same holds true for vecs/mats/quats that are composed of these types
+                 if (
+                     type_is_vector(target_type.dtype)
+                     or type_is_quaternion(target_type.dtype)
+                     or type_is_matrix(target_type.dtype)
+                 ):
+                     dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
+                     if dtype in warp.types.non_atomic_types:
+                         make_new_assign_statement()
+                         return
+
+                 kernel_name = adj.fun_name
+                 filename = adj.filename
+                 lineno = adj.lineno + adj.fun_lineno
+
+                 if isinstance(node.op, ast.Add):
+                     adj.add_builtin_call("atomic_add", [target, *indices, rhs])
+
+                     if warp.config.verify_autograd_array_access:
+                         target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
+
+                 elif isinstance(node.op, ast.Sub):
+                     adj.add_builtin_call("atomic_sub", [target, *indices, rhs])
+
+                     if warp.config.verify_autograd_array_access:
+                         target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
+                 else:
+                     if warp.config.verbose:
+                         print(f"Warning: in-place op {node.op} is not differentiable")
+                     make_new_assign_statement()
+                     return
+
+             elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
+                 if isinstance(node.op, ast.Add):
+                     adj.add_builtin_call("add_inplace", [target, *indices, rhs])
+                 elif isinstance(node.op, ast.Sub):
+                     adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
+                 else:
+                     if warp.config.verbose:
+                         print(f"Warning: in-place op {node.op} is not differentiable")
+                     make_new_assign_statement()
+                     return
+
+             else:
+                 raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
+
+         # TODO
+         elif isinstance(lhs, ast.Attribute):
+             make_new_assign_statement()
+             return
+
+         else:
+             make_new_assign_statement()
+             return
+
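+     # Example (illustrative, not from the original source): augmented assignment
+     # on arrays lowers to an atomic when the dtype supports it, and otherwise
+     # falls back to the load/op/store rewrite above:
+     #
+     #     @wp.kernel
+     #     def accumulate(hist: wp.array(dtype=wp.int32), small: wp.array(dtype=wp.int8)):
+     #         hist[0] += 1            # -> "atomic_add"
+     #         small[0] += wp.int8(1)  # int8 is non-atomic -> small[0] = small[0] + wp.int8(1)
+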
+     def emit_Tuple(adj, node):
+         # LHS for expressions, such as i, j, k = 1, 2, 3
+         return tuple(adj.eval(x) for x in node.elts)
+
+     def emit_Pass(adj, node):
+         pass
+
+     node_visitors = {
+         ast.FunctionDef: emit_FunctionDef,
+         ast.If: emit_If,
+         ast.Compare: emit_Compare,
+         ast.BoolOp: emit_BoolOp,
+         ast.Name: emit_Name,
+         ast.Attribute: emit_Attribute,
+         ast.Constant: emit_Constant,
+         ast.BinOp: emit_BinOp,
+         ast.UnaryOp: emit_UnaryOp,
+         ast.While: emit_While,
+         ast.For: emit_For,
+         ast.Break: emit_Break,
+         ast.Continue: emit_Continue,
+         ast.Expr: emit_Expr,
+         ast.Call: emit_Call,
+         ast.Index: emit_Index,  # Deprecated in 3.9
+         ast.Subscript: emit_Subscript,
+         ast.Assign: emit_Assign,
+         ast.Return: emit_Return,
+         ast.AugAssign: emit_AugAssign,
+         ast.Tuple: emit_Tuple,
+         ast.Pass: emit_Pass,
+         ast.Assert: emit_Assert,
+     }
+
+     def eval(adj, node):
+         if hasattr(node, "lineno"):
+             adj.set_lineno(node.lineno - 1)
+
+         try:
+             emit_node = adj.node_visitors[type(node)]
+         except KeyError as e:
+             type_name = type(node).__name__
+             namespace = "ast." if isinstance(node, ast.AST) else ""
+             raise WarpCodegenError(f"Construct `{namespace}{type_name}` not supported in kernels.") from e
+
+         return emit_node(adj, node)
+
+     # helper to evaluate expressions of the form
+     # obj1.obj2.obj3.attr in the function's global scope
+     def resolve_path(adj, path):
+         if len(path) == 0:
+             return None
+
+         # if root is overshadowed by local symbols, bail out
+         if path[0] in adj.symbols:
+             return None
+
+         # look up in closure/global variables
+         expr = adj.resolve_external_reference(path[0])
+
+         # Support Warp types in kernels without the module prefix (e.g. v = vec3(0.0, 0.2, 0.4))
+         if expr is None:
+             expr = getattr(warp, path[0], None)
+
+         # look up in builtins
+         if expr is None:
+             expr = __builtins__.get(path[0])
+
+         if expr is not None:
+             for i in range(1, len(path)):
+                 if hasattr(expr, path[i]):
+                     expr = getattr(expr, path[i])
+
+         return expr
+
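+     # Example (illustrative, not from the original source): with a module-level
+     # constant `TAU = 6.2831853`, a kernel body can reference it directly:
+     #
+     #     @wp.kernel
+     #     def k(out: wp.array(dtype=float)):
+     #         out[0] = TAU               # found via resolve_external_reference("TAU")
+     #         v = vec3(0.0, 0.2, 0.4)    # bare Warp type, found via getattr(warp, "vec3")
+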
+     # retrieves a dictionary of all closure and global variables and their values
+     # to be used in the evaluation context of wp.static() expressions
+     def get_static_evaluation_context(adj):
+         closure_vars = dict(
+             zip(
+                 adj.func.__code__.co_freevars,
+                 [c.cell_contents for c in (adj.func.__closure__ or [])],
+             )
+         )
+
+         vars_dict = {}
+         vars_dict.update(adj.func.__globals__)
+         # variables captured in closure have precedence over global vars
+         vars_dict.update(closure_vars)
+
+         return vars_dict
+
+     def is_static_expression(adj, func):
+         return (
+             isinstance(func, types.FunctionType)
+             and func.__module__ == "warp.builtins"
+             and func.__qualname__ == "static"
+         )
+
+     # verify that the value returned by a wp.static() expression is supported inside a Warp kernel
+     def verify_static_return_value(adj, value):
+         if value is None:
+             raise ValueError("None is returned")
+         if warp.types.is_value(value):
+             return True
+         if warp.types.is_array(value):
+             # more useful explanation for the common case of creating a Warp array
+             raise ValueError("a Warp array cannot be created inside Warp kernels")
+         if isinstance(value, str):
+             # we want to support cases such as `print(wp.static("test"))`
+             return True
+         if isinstance(value, warp.context.Function):
+             return True
+
+         def verify_struct(s: StructInstance, attr_path: List[str]):
+             for key in s._cls.vars.keys():
+                 v = getattr(s, key)
+                 if issubclass(type(v), StructInstance):
+                     verify_struct(v, attr_path + [key])
+                 else:
+                     try:
+                         adj.verify_static_return_value(v)
+                     except ValueError as e:
+                         raise ValueError(
+                             f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(attr_path)}"
+                         ) from e
+
+         if issubclass(type(value), StructInstance):
+             return verify_struct(value, [])
+
+         raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
+
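+     # Example (illustrative, not from the original source): values accepted or
+     # rejected by the check above:
+     #
+     #     wp.static(3.0)            # OK: value type
+     #     wp.static("debug")        # OK: string, e.g. print(wp.static("debug"))
+     #     wp.static(my_func)        # OK if my_func is a Warp function (hypothetical name)
+     #     wp.static(wp.zeros(8))    # rejected: arrays cannot be created inside kernels
+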
+     # find the source code string of an AST node
+     def extract_node_source(adj, node) -> Optional[str]:
+         if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
+             return None
+
+         start_line = node.lineno - 1  # line numbers start at 1
+         start_col = node.col_offset
+
+         if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
+             end_line = node.end_lineno - 1
+             end_col = node.end_col_offset
+         else:
+             # fallback for Python versions before 3.8:
+             # we have to find the end line and column manually
+             end_line = start_line
+             end_col = start_col
+             parenthesis_count = 1
+             for lineno in range(start_line, len(adj.source_lines)):
+                 if lineno == start_line:
+                     c_start = start_col
+                 else:
+                     c_start = 0
+                 line = adj.source_lines[lineno]
+                 for i in range(c_start, len(line)):
+                     c = line[i]
+                     if c == "(":
+                         parenthesis_count += 1
+                     elif c == ")":
+                         parenthesis_count -= 1
+                         if parenthesis_count == 0:
+                             end_col = i
+                             end_line = lineno
+                             break
+                 if parenthesis_count == 0:
+                     break
+
+         if start_line == end_line:
+             # single-line expression
+             return adj.source_lines[start_line][start_col:end_col]
+         else:
+             # multi-line expression
+             lines = []
+             # first line (from start_col to the end)
+             lines.append(adj.source_lines[start_line][start_col:])
+             # middle lines (entire lines)
+             lines.extend(adj.source_lines[start_line + 1 : end_line])
+             # last line (from the start to end_col)
+             lines.append(adj.source_lines[end_line][:end_col])
+             return "\n".join(lines).strip()
+
+     # handles a wp.static() expression and returns the resulting object and a string
+     # representing the code of the static expression
+     def evaluate_static_expression(adj, node) -> Tuple[Any, str]:
+         if len(node.args) == 1:
+             static_code = adj.extract_node_source(node.args[0])
+         elif len(node.keywords) == 1:
+             static_code = adj.extract_node_source(node.keywords[0])
+         else:
+             raise WarpCodegenError("warp.static() requires a single argument or keyword")
+         if static_code is None:
+             raise WarpCodegenError("Error extracting source code from wp.static() expression")
+
+         # Since this is an expression, we can enforce it to be defined on a single line.
+         static_code = static_code.replace("\n", "")
+
+         vars_dict = adj.get_static_evaluation_context()
+         # add constant variables to the static call context
+         constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
+         vars_dict.update(constant_vars)
+
+         # Replace all constant `len()` expressions with their value.
+         if "len" in static_code:
+
+             def eval_len(obj):
+                 if type_is_vector(obj):
+                     return obj._length_
+                 elif type_is_quaternion(obj):
+                     return obj._length_
+                 elif type_is_matrix(obj):
+                     return obj._shape_[0]
+                 elif type_is_transformation(obj):
+                     return obj._length_
+                 elif is_tile(obj):
+                     return obj.shape[0]
+
+                 return len(obj)
+
+             len_expr_ctx = vars_dict.copy()
+             constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
+             len_expr_ctx.update(constant_types)
+             len_expr_ctx.update({"len": eval_len})
+
+             # We want to replace the expression code in-place,
+             # so reparse it to get the correct column info.
+             len_value_locs: List[Tuple[int, int, int]] = []
+             expr_tree = ast.parse(static_code)
+             assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
+             expr_root = expr_tree.body[0].value
+             for expr_node in ast.walk(expr_root):
+                 if (
+                     isinstance(expr_node, ast.Call)
+                     and getattr(expr_node.func, "id", None) == "len"
+                     and len(expr_node.args) == 1
+                 ):
+                     len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
+                     try:
+                         len_value = eval(len_expr, len_expr_ctx)
+                     except Exception:
+                         pass
+                     else:
+                         len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))
+
+             if len_value_locs:
+                 new_static_code = ""
+                 loc = 0
+                 for value, start, end in len_value_locs:
+                     new_static_code += f"{static_code[loc:start]}{value}"
+                     loc = end
+
+                 new_static_code += static_code[len_value_locs[-1][2] :]
+                 static_code = new_static_code
+
+         try:
+             value = eval(static_code, vars_dict)
+             if warp.config.verbose:
+                 print(f"Evaluated static command: {static_code} = {value}")
+         except NameError as e:
+             raise WarpCodegenError(
+                 f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
+             ) from e
+         except Exception as e:
+             raise WarpCodegenError(
+                 f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
+             ) from e
+
+         try:
+             adj.verify_static_return_value(value)
+         except ValueError as e:
+             raise WarpCodegenError(
+                 f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
+             ) from e
+
+         return value, static_code
+
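+     # Example (illustrative, not from the original source): the len() folding
+     # above lets a static expression depend on the *type* of a local variable:
+     #
+     #     v = wp.vec3(1.0, 2.0, 3.0)
+     #     n = wp.static(len(v) * 2)   # rewritten to "3 * 2" before eval()
+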
+     # try to replace wp.static() expressions by their evaluated value if the
+     # expression can be evaluated
+     def replace_static_expressions(adj):
+         class StaticExpressionReplacer(ast.NodeTransformer):
+             def visit_Call(self, node):
+                 func, _ = adj.resolve_static_expression(node.func, eval_types=False)
+                 if adj.is_static_expression(func):
+                     try:
+                         # the static expression will execute as long as it is valid and
+                         # only depends on global or captured variables
+                         obj, code = adj.evaluate_static_expression(node)
+                         if code is not None:
+                             adj.static_expressions[code] = obj
+                             if isinstance(obj, warp.context.Function):
+                                 name_node = ast.Name("__warp_func__")
+                                 # we add a pointer to the Warp function here so that we can refer to it later at
+                                 # codegen time (note that the function key itself is not sufficient to uniquely
+                                 # identify the function, as the function may be redefined between the current time
+                                 # of wp.static() declaration and the time of codegen during module building)
+                                 name_node.warp_func = obj
+                                 return ast.copy_location(name_node, node)
+                             else:
+                                 return ast.copy_location(ast.Constant(value=obj), node)
+                     except Exception:
+                         # Ignoring failing static expressions should generally not be an issue because only
+                         # one of these cases should be possible:
+                         # 1) the static expression itself is invalid code, in which case the module cannot be
+                         #    built at all,
+                         # 2) the static expression contains a reference to a local (even if constant) variable
+                         #    (and is therefore not executable and raises this exception), in which
+                         #    case changing the constant, or the code affecting this constant, would lead to
+                         #    a different module hash anyway.
+                         pass
+
+                 return self.generic_visit(node)
+
+         adj.tree = StaticExpressionReplacer().visit(adj.tree)
+
+     # Evaluates a static expression that does not depend on runtime values;
+     # if eval_types is True, try resolving the path using evaluated type information as well
+     def resolve_static_expression(adj, root_node, eval_types=True):
+         attributes = []
+
+         node = root_node
+         while isinstance(node, ast.Attribute):
+             attributes.append(node.attr)
+             node = node.value
+
+         if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
+             # support for operators returning modules,
+             # i.e. operator_name(*operator_args).x.y.z
+             operator_args = node.args
+             operator_name = node.func.id
+
+             if operator_name == "type":
+                 if len(operator_args) != 1:
+                     raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")
+
+                 # type() operator
+                 var = adj.eval(operator_args[0])
+
+                 if isinstance(var, Var):
+                     var_type = strip_reference(var.type)
+                     # Allow accessing type attributes, for instance array.dtype
+                     while attributes:
+                         attr_name = attributes.pop()
+                         var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type
+
+                         if var_type is None:
+                             raise WarpCodegenAttributeError(
+                                 f"{attr_name} is not an attribute of {type_repr(prev_type)}"
+                             )
+
+                     return var_type, [str(var_type)]
+                 else:
+                     raise WarpCodegenError(f"Cannot deduce the type of {var}")
+
+         # reverse list since ast presents it in backward order
+         path = [*reversed(attributes)]
+         if isinstance(node, ast.Name):
+             path.insert(0, node.id)
+
+         # Try resolving path from captured context
+         captured_obj = adj.resolve_path(path)
+         if captured_obj is not None:
+             return captured_obj, path
+
+         return None, path
+
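+     # Example (illustrative, not from the original source): the type() operator
+     # path above enables statically-typed construction inside kernels, e.g.:
+     #
+     #     @wp.kernel
+     #     def k(a: wp.array(dtype=wp.float16)):
+     #         x = type(a[0])(0.0)   # resolves to wp.float16 at codegen time
+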
+     def resolve_external_reference(adj, name: str):
+         try:
+             # look up in closure variables
+             idx = adj.func.__code__.co_freevars.index(name)
+             obj = adj.func.__closure__[idx].cell_contents
+         except ValueError:
+             # look up in global variables
+             obj = adj.func.__globals__.get(name)
+         return obj
+
+     # annotate generated code with the original source code line
+     def set_lineno(adj, lineno):
+         if adj.lineno is None or adj.lineno != lineno:
+             line = lineno + adj.fun_lineno
+             source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
+             adj.add_forward(f"// {source} <L {line}>")
+             adj.add_reverse(f"// adj: {source} <L {line}>")
+         adj.lineno = lineno
+
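+     # Example (illustrative, not from the original source): for kernel source
+     # line 42 reading "x = a + b", both generated streams are annotated with a
+     # padded comment of the form
+     #
+     #     // x = a + b        <L 42>       (forward stream)
+     #     // adj: x = a + b   <L 42>       (reverse stream)
+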
+     def get_node_source(adj, node):
+         # return the Python code corresponding to the given AST node
+         return ast.get_source_segment(adj.source, node)
+
+     def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.context.Function, Any]]:
+         """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
+
+         local_variables = set()  # Track local variables appearing on the LHS so we know when variables are shadowed
+
+         constants: Dict[str, Any] = {}
+         types: Dict[Union[Struct, type], Any] = {}
+         functions: Dict[warp.context.Function, Any] = {}
+
+         for node in ast.walk(adj.tree):
+             if isinstance(node, ast.Name) and node.id not in local_variables:
+                 # look up in closure/global variables
+                 obj = adj.resolve_external_reference(node.id)
+                 if warp.types.is_value(obj):
+                     constants[node.id] = obj
+
+             elif isinstance(node, ast.Attribute):
+                 obj, path = adj.resolve_static_expression(node, eval_types=False)
+                 if warp.types.is_value(obj):
+                     constants[".".join(path)] = obj
+
+             elif isinstance(node, ast.Call):
+                 func, _ = adj.resolve_static_expression(node.func, eval_types=False)
+                 if isinstance(func, warp.context.Function) and not func.is_builtin():
+                     # calling user-defined function
+                     functions[func] = None
+                 elif isinstance(func, Struct):
+                     # calling struct constructor
+                     types[func] = None
+                 elif isinstance(func, type) and warp.types.type_is_value(func):
+                     # calling value type constructor
+                     types[func] = None
+
+             elif isinstance(node, ast.Assign):
+                 # Add the LHS names to the local_variables so we know any subsequent uses are shadowed
+                 lhs = node.targets[0]
+                 if isinstance(lhs, ast.Tuple):
+                     for v in lhs.elts:
+                         if isinstance(v, ast.Name):
+                             local_variables.add(v.id)
+                 elif isinstance(lhs, ast.Name):
+                     local_variables.add(lhs.id)
+
+         return constants, types, functions
+
+
+ # ----------------
+ # code generation
+
+ cpu_module_header = """
+ #define WP_TILE_BLOCK_DIM {block_dim}
+ #define WP_NO_CRT
+ #include "builtin.h"
+
+ // avoid namespacing of the float type when casting, since wp::float(x) is not valid in C++
+ #define float(x) cast_float(x)
+ #define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
+
+ #define int(x) cast_int(x)
+ #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
+
+ #define builtin_tid1d() wp::tid(task_index, dim)
+ #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
+
+ """
+
+ cuda_module_header = """
+ #define WP_TILE_BLOCK_DIM {block_dim}
+ #define WP_NO_CRT
+ #include "builtin.h"
+
+ // avoid namespacing of the float type when casting, since wp::float(x) is not valid in C++
+ #define float(x) cast_float(x)
+ #define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
+
+ #define int(x) cast_int(x)
+ #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
+
+ #define builtin_tid1d() wp::tid(_idx, dim)
+ #define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
+
+ """
+
+ struct_template = """
+ struct {name}
+ {{
+ {struct_body}
+
+     {defaulted_constructor_def}
+     CUDA_CALLABLE {name}({forward_args})
+     {forward_initializers}
+     {{
+     }}
+
+     CUDA_CALLABLE {name}& operator += (const {name}& rhs)
+     {{{prefix_add_body}
+         return *this;}}
+
+ }};
+
+ static CUDA_CALLABLE void adj_{name}({reverse_args})
+ {{
+ {reverse_body}}}
+
+ CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
+ {{
+ {atomic_add_body}}}
+
+
+ """
+
+ cpu_forward_function_template = """
+ // {filename}:{lineno}
+ static {return_type} {name}(
+     {forward_args})
+ {{
+ {forward_body}}}
+
+ """
+
+ cpu_reverse_function_template = """
+ // {filename}:{lineno}
+ static void adj_{name}(
+     {reverse_args})
+ {{
+ {reverse_body}}}
+
+ """
+
+ cuda_forward_function_template = """
+ // {filename}:{lineno}
+ {line_directive}static CUDA_CALLABLE {return_type} {name}(
+     {forward_args})
+ {{
+ {forward_body}{line_directive}}}
+
+ """
+
+ cuda_reverse_function_template = """
+ // {filename}:{lineno}
+ {line_directive}static CUDA_CALLABLE void adj_{name}(
+     {reverse_args})
+ {{
+ {reverse_body}{line_directive}}}
+
+ """
+
+ cuda_kernel_template_forward = """
+
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
+     {forward_args})
+ {{
+ {line_directive}    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+ {line_directive}         _idx < dim.size;
+ {line_directive}         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+     {{
+         // reset shared memory allocator
+ {line_directive}        wp::tile_alloc_shared(0, true);
+
+ {forward_body}{line_directive}    }}
+ {line_directive}}}
+
+ """
+
+ cuda_kernel_template_backward = """
+
+ {line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
+     {reverse_args})
+ {{
+ {line_directive}    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+ {line_directive}         _idx < dim.size;
+ {line_directive}         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+     {{
+         // reset shared memory allocator
+ {line_directive}        wp::tile_alloc_shared(0, true);
+
+ {reverse_body}{line_directive}    }}
+ {line_directive}}}
+
+ """
+
+ cpu_kernel_template_forward = """
+
+ void {name}_cpu_kernel_forward(
+     {forward_args})
+ {{
+ {forward_body}}}
+
+ """
+
+ cpu_kernel_template_backward = """
+
+ void {name}_cpu_kernel_backward(
+     {reverse_args})
+ {{
+ {reverse_body}}}
+
+ """
+
+ cpu_module_template_forward = """
+
+ extern "C" {{
+
+ // Python CPU entry points
+ WP_API void {name}_cpu_forward(
+     {forward_args})
+ {{
+     for (size_t task_index = 0; task_index < dim.size; ++task_index)
+     {{
+         // init shared memory allocator
+         wp::tile_alloc_shared(0, true);
+
+         {name}_cpu_kernel_forward(
+             {forward_params});
+
+         // check shared memory allocator
+         wp::tile_alloc_shared(0, false, true);
+
+     }}
+ }}
+
+ }} // extern C
+
+ """
+
+ cpu_module_template_backward = """
+
+ extern "C" {{
+
+ WP_API void {name}_cpu_backward(
+     {reverse_args})
+ {{
+     for (size_t task_index = 0; task_index < dim.size; ++task_index)
+     {{
+         // initialize shared memory allocator
+         wp::tile_alloc_shared(0, true);
+
+         {name}_cpu_kernel_backward(
+             {reverse_params});
+
+         // check shared memory allocator
+         wp::tile_alloc_shared(0, false, true);
+     }}
+ }}
+
+ }} // extern C
+
+ """
+
+
+ # converts a constant Python value to its equivalent C representation
+ def constant_str(value):
+     value_type = type(value)
+
+     if value_type == bool or value_type == builtins.bool:
+         if value:
+             return "true"
+         else:
+             return "false"
+
+     elif value_type == str:
+         # ensure constant strings are correctly escaped
+         return '"' + str(value.encode("unicode-escape").decode()) + '"'
+
+     elif isinstance(value, ctypes.Array):
+         if value_type._wp_scalar_type_ == float16:
+             # special case for float16, which is stored as uint16 in the ctypes.Array
+             from warp.context import runtime
+
+             scalar_value = runtime.core.half_bits_to_float
+         else:
+
+             def scalar_value(x):
+                 return x
+
+         # list of scalar initializer values
+         initlist = []
+         for i in range(value._length_):
+             x = ctypes.Array.__getitem__(value, i)
+             initlist.append(str(scalar_value(x)).lower())
+
+         if value._wp_scalar_type_ is bool:
+             dtypestr = f"wp::initializer_array<{value._length_},{value._wp_scalar_type_.__name__}>"
+         else:
+             dtypestr = f"wp::initializer_array<{value._length_},wp::{value._wp_scalar_type_.__name__}>"
+
+         # construct value from initializer array, e.g. wp::initializer_array<4,wp::float32>{1.0, 2.0, 3.0, 4.0}
+         return f"{dtypestr}{{{', '.join(initlist)}}}"
+
+     elif value_type in warp.types.scalar_types:
+         # make sure we emit the value of objects, e.g. uint32
+         return str(value.value)
+
+     elif issubclass(value_type, warp.codegen.StructInstance):
+         # constant struct instance
+         arg_strs = []
+         for key, var in value._cls.vars.items():
+             attr = getattr(value, key)
+             arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
+         arg_str = ", ".join(arg_strs)
+         return f"{value.native_name}({arg_str})"
+
+     elif value == math.inf:
+         return "INFINITY"
+
+     elif math.isnan(value):
+         return "NAN"
+
+     else:
+         # otherwise just convert the constant to a string
+         return str(value)
+
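+ # Example (illustrative, not from the original source): representative
+ # conversions performed by constant_str():
+ #
+ #     constant_str(True)               -> "true"
+ #     constant_str(wp.uint32(7))       -> "7"
+ #     constant_str(wp.vec2(1.0, 2.0))  -> "wp::initializer_array<2,wp::float32>{1.0, 2.0}"
+ #     constant_str(math.inf)           -> "INFINITY"
+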
+
+ def indent(args, stops=1):
+     sep = ",\n"
+     for _i in range(stops):
+         sep += "    "
+
+     # return sep + args.replace(", ", "," + sep)
+     return sep.join(args)
+
+
+ # generates a C function name based on the Python function name
+ def make_full_qualified_name(func: Union[str, Callable]) -> str:
+     if not isinstance(func, str):
+         func = func.__qualname__
+     return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
+
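+ # Example (illustrative, not from the original source):
+ #
+ #     make_full_qualified_name("Cloth.step")       -> "Cloth__step"
+ #     make_full_qualified_name("make.<locals>.k")  -> "make__locals__k"
+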
+
+ def codegen_struct(struct, device="cpu", indent_size=4):
+     name = struct.native_name
+
+     body = []
+     indent_block = " " * indent_size
+
+     if len(struct.vars) > 0:
+         for label, var in struct.vars.items():
+             body.append(var.ctype() + " " + label + ";\n")
+     else:
+         # for empty structs, emit the dummy attribute to avoid any compiler-specific alignment issues
+         body.append("char _dummy_;\n")
+
+     forward_args = []
+     reverse_args = []
+
+     forward_initializers = []
+     reverse_body = []
+     atomic_add_body = []
+     prefix_add_body = []
+
+     # forward args
+     for label, var in struct.vars.items():
+         var_ctype = var.ctype()
+         default_arg_def = " = {}" if forward_args else ""
+         forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
+         reverse_args.append(f"{var_ctype} const&")
+
+         namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
+         atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")
+
+         prefix = f"{indent_block}," if forward_initializers else ":"
+         forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")
+
+     # prefix-add operator
+     for label, var in struct.vars.items():
+         if not is_array(var.type):
+             prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
+
+     # reverse args
+     for label, var in struct.vars.items():
+         reverse_args.append(var.ctype() + " & adj_" + label)
+         if is_array(var.type):
+             reverse_body.append(f"{indent_block}adj_{label} = adj_ret.{label};\n")
+         else:
+             reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")
+
+     reverse_args.append(name + " & adj_ret")
+
+     # explicitly defaulted default constructor if no default constructor has been defined
+     defaulted_constructor_def = f"{name}() = default;" if forward_args else ""
+
+     return struct_template.format(
+         name=name,
+         struct_body="".join([indent_block + l for l in body]),
+         forward_args=indent(forward_args),
+         forward_initializers="".join(forward_initializers),
+         reverse_args=indent(reverse_args),
+         reverse_body="".join(reverse_body),
+         prefix_add_body="".join(prefix_add_body),
+         atomic_add_body="".join(atomic_add_body),
+         defaulted_constructor_def=defaulted_constructor_def,
+     )
+
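+ # Example (illustrative sketch, not from the original source): for a struct
+ #
+ #     @wp.struct
+ #     class Particle:
+ #         pos: wp.vec3
+ #         mass: float
+ #
+ # codegen_struct() fills struct_template to emit (abridged) C++ of the form
+ #
+ #     struct Particle_mangled { wp::vec3 pos; wp::float32 mass; ... };
+ #     static CUDA_CALLABLE void adj_Particle_mangled(...);
+ #     CUDA_CALLABLE void adj_atomic_add(Particle_mangled* p, Particle_mangled t);
+ #
+ # where "Particle_mangled" stands in for the struct's native_name.
+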
+
+ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
+     if device == "cpu":
+         indent = 4
+     elif device == "cuda":
+         if func_type == "kernel":
+             indent = 8
+         else:
+             indent = 4
+     else:
+         raise ValueError(f"Device {device} not supported for codegen")
+
+     indent_block = " " * indent
+
+     # primal vars
+     lines = []
+     lines += ["//---------\n"]
+     lines += ["// primal vars\n"]
+
+     for var in adj.variables:
+         if is_tile(var.type):
+             lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"]
+         elif var.constant is None:
+             lines += [f"{var.ctype()} {var.emit()};\n"]
+         else:
+             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
+
+         if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
+             lines.insert(-1, f"{line_directive}\n")
+
+     # forward pass
+     lines += ["//---------\n"]
+     lines += ["// forward\n"]
+
+     for f in adj.blocks[0].body_forward:
+         lines += [f + "\n"]
+
+     return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
+
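+ # Example (illustrative, not from the original source): the emitted forward body
+ # groups declarations first, then the straight-line forward statements, e.g.
+ #
+ #     //---------
+ #     // primal vars
+ #     wp::int32 var_0;
+ #     wp::vec3 var_1;
+ #     //---------
+ #     // forward
+ #     var_0 = builtin_tid1d();
+ #     ...
+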
+
+ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
+     if device == "cpu":
+         indent = 4
+     elif device == "cuda":
+         if func_type == "kernel":
+             indent = 8
+         else:
+             indent = 4
+     else:
+         raise ValueError(f"Device {device} not supported for codegen")
+
+     indent_block = " " * indent
+
+     lines = []
+
+     # primal vars
+     lines += ["//---------\n"]
+     lines += ["// primal vars\n"]
+
+     for var in adj.variables:
+         if is_tile(var.type):
+             lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
+         elif var.constant is None:
+             lines += [f"{var.ctype()} {var.emit()};\n"]
+         else:
+             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
+
+         if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
+             lines.insert(-1, f"{line_directive}\n")
+
+     # dual vars
+     lines += ["//---------\n"]
+     lines += ["// dual vars\n"]
+
+     for var in adj.variables:
+         name = var.emit_adj()
+         ctype = var.ctype(value_type=True)
+
+         if is_tile(var.type):
+             if var.type.storage == "register":
+                 # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+                 lines += [f"{var.type.ctype()} {name}(0.0);\n"]
+             elif var.type.storage == "shared":
+                 # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+                 lines += [f"{var.type.ctype()}& {name} = {var.emit()};\n"]
+         else:
+             lines += [f"{ctype} {name} = {{}};\n"]
+
+         if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
+             lines.insert(-1, f"{line_directive}\n")
+
+     # forward pass
+     lines += ["//---------\n"]
+     lines += ["// forward\n"]
+
+     for f in adj.blocks[0].body_replay:
+         lines += [f + "\n"]
+
+     # reverse pass
+     lines += ["//---------\n"]
+     lines += ["// reverse\n"]
+
+     for l in reversed(adj.blocks[0].body_reverse):
+         lines += [l + "\n"]
+
+     # In grid-stride kernels the reverse body is in a for loop
+     if device == "cuda" and func_type == "kernel":
+         lines += ["continue;\n"]
+     else:
+         lines += ["return;\n"]
+
+     return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
+
+
+ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
+     if options is None:
+         options = {}
+
+     if adj.return_var is not None and "return" in adj.arg_types:
+         if get_origin(adj.arg_types["return"]) is tuple:
+             if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
+                 raise WarpCodegenError(
+                     f"The function `{adj.fun_name}` has its return type "
+                     f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
+                     f"but the code returns {len(adj.return_var)} values."
+                 )
+             elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
+                 raise WarpCodegenError(
+                     f"The function `{adj.fun_name}` has its return type "
+                     f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                     f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
+                 )
+         elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
+             raise WarpCodegenError(
+                 f"The function `{adj.fun_name}` has its return type "
+                 f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                 f"but the code returns {len(adj.return_var)} values."
+             )
+         elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
+             raise WarpCodegenError(
+                 f"The function `{adj.fun_name}` has its return type "
+                 f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                 f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
+             )
+
+     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers).
+     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
+     # a direct mapping to a Python source line.
+     func_line_directive = ""
+     if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
+         func_line_directive = f"{line_directive}\n"
+
+     # forward header
+     if adj.return_var is not None and len(adj.return_var) == 1:
+         return_type = adj.return_var[0].ctype()
+     else:
+         return_type = "void"
+
+     has_multiple_outputs = adj.return_var is not None and len(adj.return_var) != 1
+
+     forward_args = []
+     reverse_args = []
+
+     # forward args
+     for i, arg in enumerate(adj.args):
+         s = f"{arg.ctype()} {arg.emit()}"
+         forward_args.append(s)
+         if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
+             reverse_args.append(s)
+     if has_multiple_outputs:
+         for i, arg in enumerate(adj.return_var):
+             forward_args.append(arg.ctype() + " & ret_" + str(i))
+             reverse_args.append(arg.ctype() + " & ret_" + str(i))
+
+     # reverse args
+     for i, arg in enumerate(adj.args):
+         if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
+             break
+         # indexed array gradients are regular arrays
+         if isinstance(arg.type, indexedarray):
+             _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+             reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
+         else:
+             reverse_args.append(arg.ctype() + " & adj_" + arg.label)
+     if has_multiple_outputs:
+         for i, arg in enumerate(adj.return_var):
+             reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
+     elif return_type != "void":
+         reverse_args.append(return_type + " & adj_ret")
+     # custom output reverse args (user-declared)
+     if adj.custom_reverse_mode:
+         for arg in adj.args[adj.custom_reverse_num_input_args :]:
+             reverse_args.append(f"{arg.ctype()} & {arg.emit()}")
+
+     if device == "cpu":
+         forward_template = cpu_forward_function_template
+         reverse_template = cpu_reverse_function_template
+     elif device == "cuda":
+         forward_template = cuda_forward_function_template
+         reverse_template = cuda_reverse_function_template
+     else:
+         raise ValueError(f"Device {device} is not supported")
+
+     # codegen body
+     forward_body = codegen_func_forward(adj, func_type="function", device=device)
+
+     s = ""
+     if not adj.skip_forward_codegen:
+         s += forward_template.format(
+             name=c_func_name,
+             return_type=return_type,
+             forward_args=indent(forward_args),
+             forward_body=forward_body,
+             filename=adj.filename,
+             lineno=adj.fun_lineno,
+             line_directive=func_line_directive,
+         )
+
+     if not adj.skip_reverse_codegen:
+         if adj.custom_reverse_mode:
+             reverse_body = "\t// user-defined adjoint code\n" + forward_body
+         else:
+             if options.get("enable_backward", True):
+                 reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
+             else:
+                 reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+         s += reverse_template.format(
+             name=c_func_name,
+             return_type=return_type,
+             reverse_args=indent(reverse_args),
+             forward_body=forward_body,
+             reverse_body=reverse_body,
+             filename=adj.filename,
+             lineno=adj.fun_lineno,
+             line_directive=func_line_directive,
+         )
+
+     return s
+
+
+ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
+     if adj.return_var is not None and len(adj.return_var) == 1:
+         return_type = adj.return_var[0].ctype()
+     else:
+         return_type = "void"
+
+     forward_args = []
+     reverse_args = []
+
+     # forward args
+     for _i, arg in enumerate(adj.args):
+         s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
+         forward_args.append(s)
+         reverse_args.append(s)
+
+     # reverse args
+     for _i, arg in enumerate(adj.args):
+         if isinstance(arg.type, indexedarray):
+             _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+             reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
+         else:
+             reverse_args.append(arg.ctype() + " & adj_" + arg.label)
+     if return_type != "void":
+         reverse_args.append(return_type + " & adj_ret")
+
+     forward_template = cuda_forward_function_template
+     replay_template = cuda_forward_function_template
+     reverse_template = cuda_reverse_function_template
+
+     s = ""
+     s += forward_template.format(
+         name=name,
+         return_type=return_type,
+         forward_args=indent(forward_args),
+         forward_body=snippet,
+         filename=adj.filename,
+         lineno=adj.fun_lineno,
+         line_directive="",
+     )
+
+     if replay_snippet is not None:
+         s += replay_template.format(
+             name="replay_" + name,
+             return_type=return_type,
+             forward_args=indent(forward_args),
+             forward_body=replay_snippet,
+             filename=adj.filename,
+             lineno=adj.fun_lineno,
+             line_directive="",
+         )
+
+     if adj_snippet:
+         reverse_body = adj_snippet
+     else:
+         reverse_body = ""
+
+     s += reverse_template.format(
+         name=name,
+         return_type=return_type,
+         reverse_args=indent(reverse_args),
+         forward_body=snippet,
+         reverse_body=reverse_body,
+         filename=adj.filename,
+         lineno=adj.fun_lineno,
+         line_directive="",
+     )
+
+     return s
+
+
+ def codegen_kernel(kernel, device, options):
+     # Update the module's options with the ones defined on the kernel, if any.
+     options = dict(options)
+     options.update(kernel.options)
+
+     adj = kernel.adj
+
+     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers).
+     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
+     # a direct mapping to a Python source line.
+     func_line_directive = ""
+     if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
+         func_line_directive = f"{line_directive}\n"
+
+     if device == "cpu":
+         template_forward = cpu_kernel_template_forward
+         template_backward = cpu_kernel_template_backward
+     elif device == "cuda":
+         template_forward = cuda_kernel_template_forward
+         template_backward = cuda_kernel_template_backward
+     else:
+         raise ValueError(f"Device {device} is not supported")
+
+     template = ""
+     template_fmt_args = {
+         "name": kernel.get_mangled_name(),
+     }
+
+     # build forward signature
+     forward_args = ["wp::launch_bounds_t dim"]
+     if device == "cpu":
+         forward_args.append("size_t task_index")
+
+     for arg in adj.args:
+         forward_args.append(arg.ctype() + " var_" + arg.label)
+
+     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
+     template_fmt_args.update(
+         {
+             "forward_args": indent(forward_args),
+             "forward_body": forward_body,
+             "line_directive": func_line_directive,
+         }
+     )
+     template += template_forward
+
+     if options["enable_backward"]:
+         # build reverse signature
+         reverse_args = ["wp::launch_bounds_t dim"]
+         if device == "cpu":
+             reverse_args.append("size_t task_index")
+
+         for arg in adj.args:
+             reverse_args.append(arg.ctype() + " var_" + arg.label)
+
+         for arg in adj.args:
+             # indexed array gradients are regular arrays
+             if isinstance(arg.type, indexedarray):
+                 _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                 reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+             else:
+                 reverse_args.append(arg.ctype() + " adj_" + arg.label)
+
+         reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
+         template_fmt_args.update(
+             {
+                 "reverse_args": indent(reverse_args),
+                 "reverse_body": reverse_body,
+             }
+         )
+         template += template_backward
+
+     s = template.format(**template_fmt_args)
+     return s
+
+
+ def codegen_module(kernel, device, options):
+     if device != "cpu":
+         return ""
+
+     # Update the module's options with the ones defined on the kernel, if any.
+     options = dict(options)
+     options.update(kernel.options)
+
+     adj = kernel.adj
+
+     template = ""
+     template_fmt_args = {
+         "name": kernel.get_mangled_name(),
+     }
+
+     # build forward signature
+     forward_args = ["wp::launch_bounds_t dim"]
+     forward_params = ["dim", "task_index"]
+
+     for arg in adj.args:
+         if hasattr(arg.type, "_wp_generic_type_str_"):
+             # vectors and matrices are passed from Python by pointer
+             forward_args.append(f"const {arg.ctype()}* var_" + arg.label)
+             forward_params.append(f"*var_{arg.label}")
+         else:
+             forward_args.append(f"{arg.ctype()} var_{arg.label}")
+             forward_params.append("var_" + arg.label)
+
+     template_fmt_args.update(
+         {
+             "forward_args": indent(forward_args),
+             "forward_params": indent(forward_params, 3),
+         }
+     )
+     template += cpu_module_template_forward
+
+     if options["enable_backward"]:
+         # build reverse signature
+         reverse_args = [*forward_args]
+         reverse_params = [*forward_params]
+
+         for arg in adj.args:
+             if isinstance(arg.type, indexedarray):
+                 # indexed array gradients are regular arrays
+                 _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                 reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
+                 reverse_params.append(f"adj_{_arg.label}")
+             elif hasattr(arg.type, "_wp_generic_type_str_"):
+                 # vectors and matrices are passed from Python by pointer
+                 reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
+                 reverse_params.append(f"*adj_{arg.label}")
+             else:
+                 reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
+                 reverse_params.append(f"adj_{arg.label}")
+
+         template_fmt_args.update(
+             {
+                 "reverse_args": indent(reverse_args),
+                 "reverse_params": indent(reverse_params, 3),
+             }
+         )
+         template += cpu_module_template_backward
+
+     s = template.format(**template_fmt_args)
+     return s